From 015612d5903888dc73db7ecbb5983438d4627ecd Mon Sep 17 00:00:00 2001
From: ancarola <raffaele.ancarola@epfl.ch>
Date: Mon, 1 Jul 2019 15:58:29 +0200
Subject: Small correction on basic multiplication

---
 include/mm/mmiterator.hpp | 84 ++++++++++++++++++++++++++++++-----------------
 include/mm/mmmatrix.hpp   |  4 ++-
 2 files changed, 56 insertions(+), 32 deletions(-)

(limited to 'include')

diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp
index 7b3480e..ed406fe 100644
--- a/include/mm/mmiterator.hpp
+++ b/include/mm/mmiterator.hpp
@@ -28,10 +28,13 @@ public:
     vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0)
         : M(_M), position(pos), index(i) {}
 
+#ifdef MM_IMPLICIT_CONVERSION_ITERATOR
     operator T&()
     {
+        npdebug("Calling +")
         return *(*this);
     }
+#endif
 
     IterType operator++()
     {
@@ -69,11 +72,18 @@ public:
         return index != other.index;
     }
 
-    virtual bool ok() const = 0;
+    bool ok() const
+    {
+        return index < size();
+    }
+
+    virtual std::size_t size() const = 0;
     
     virtual T& operator*() = 0;
     virtual T& operator[](std::size_t) = 0;
 
+    virtual T& operator[](std::size_t) const = 0;
+
     IterType begin()
     {
         return IterType(M, position, 0);
@@ -81,32 +91,6 @@ public:
 
     virtual IterType end() = 0;
     
-    /*
-     * Scalar product 
-     */
-
-    template<std::size_t P>
-    T operator*(const mm::iter::vector_iterator<T, Rows, P, IterType, Grid>& v)
-    {
-        T out(0);
-
-        for (unsigned k(0); k < Rows; ++k)
-            out += (*this)[k] * v[k];
-
-        return out;
-    }
-
-    template<std::size_t P>
-    T operator*(const mm::iter::vector_iterator<T, P, Cols, IterType, Grid>& v)
-    {
-        T out(0);
-
-        for (unsigned k(0); k < Cols; ++k)
-            out += (*this)[k] * v[k];
-
-        return out;
-    }
-
 protected:
 
     Grid& M; // grid mapping
@@ -118,6 +102,28 @@ protected:
     virtual IterType cpy() = 0;
 };
 
+
+/*
+ * Scalar product 
+ */
+
+template<typename T, 
+    std::size_t R1, std::size_t C1, 
+    std::size_t R2, std::size_t C2, 
+    class IterType1, class IterType2,
+    class Grid1, class Grid2>
+typename std::remove_const<T>::type operator*(const mm::iter::vector_iterator<T, R1, C1, IterType1, Grid1>& v,
+            const mm::iter::vector_iterator<T, R2, C2, IterType2, Grid2>& w)
+{
+    typename std::remove_const<T>::type out(0);
+    const std::size_t N = std::min(v.size(), w.size());
+
+    for(unsigned i = 0; i < N; ++i)
+        out += v[i] * w[i];
+
+    return out;
+}
+
 template<typename T, std::size_t Rows, std::size_t Cols, class Grid>
 class mm::iter::basic_iterator : public mm::iter::vector_iterator<T, Rows, Cols, mm::iter::basic_iterator<T, Rows, Cols, Grid>, Grid>
 {
@@ -147,11 +153,12 @@ public:
            assert(pos < Cols);
     }
 
-    virtual bool ok() const override
+    virtual std::size_t size() const
     {
-        return (direction) ? this->index < Cols : this->index < Rows;
+        return (direction) ? Cols : Rows;
     }
 
+    
     virtual T& operator*() override
     {
         return (direction) ?
@@ -167,6 +174,13 @@ public:
             this->M.data[i * Cols + this->position];
     }
 
+    virtual T& operator[](std::size_t i) const override
+    {
+        return (direction) ?
+            this->M.data[this->position * Cols + i] :
+            this->M.data[i * Cols + this->position];
+    }
+
     virtual mm::iter::basic_iterator<T, Rows, Cols, Grid> end()
     {
         return mm::iter::basic_iterator<T, Rows, Cols, Grid>(this->M, this->position, 
@@ -198,9 +212,9 @@ public:
         assert(this->position < N);
     }
 
-    virtual bool ok() const 
+    virtual std::size_t size() const
     {
-        return this->index < N;
+        return N - this->position;
     }
 
     virtual T& operator*() override
@@ -217,6 +231,14 @@ public:
             this->M.data[i * N + (i + this->position)];
     }
 
+    virtual T& operator[](std::size_t i) const override
+    {
+        return (sign) ?
+            this->M.data[(i - this->position) * N + i] :
+            this->M.data[i * N + (i + this->position)];
+    }
+
+
     virtual mm::iter::diag_iterator<T, N, Grid> end()
     {
         return mm::iter::diag_iterator<T, N, Grid>(this->M, this->position, N);
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp
index 5eb0171..0b7a40b 100644
--- a/include/mm/mmmatrix.hpp
+++ b/include/mm/mmmatrix.hpp
@@ -414,7 +414,9 @@ mm::matrix<T, M, N> operator*(
     assert(a.cols() == b.rows());
 
     mm::matrix<T, M, N> result;
-    mm::matrix<T, P2, N> bt = b.t(); // weak transposition
+    const mm::matrix<T, P2, N> bt = b.t(); // weak transposition
+
+    npdebug("Calling *")
 
     for (unsigned row = 0; row < M; row++)
         for (unsigned col = 0; col < N; col++)
-- 
cgit v1.2.1