summaryrefslogtreecommitdiffstats
path: root/include/mm
diff options
context:
space:
mode:
Diffstat (limited to 'include/mm')
-rw-r--r--include/mm/mmiterator.hpp84
-rw-r--r--include/mm/mmmatrix.hpp4
2 files changed, 56 insertions, 32 deletions
diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp
index 7b3480e..ed406fe 100644
--- a/include/mm/mmiterator.hpp
+++ b/include/mm/mmiterator.hpp
@@ -28,10 +28,13 @@ public:
vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0)
: M(_M), position(pos), index(i) {}
+#ifdef MM_IMPLICIT_CONVERSION_ITERATOR
operator T&()
{
+ npdebug("Calling +")
return *(*this);
}
+#endif
IterType operator++()
{
@@ -69,11 +72,18 @@ public:
return index != other.index;
}
- virtual bool ok() const = 0;
+ bool ok() const
+ {
+ return index < size();
+ }
+
+ virtual std::size_t size() const = 0;
virtual T& operator*() = 0;
virtual T& operator[](std::size_t) = 0;
+ virtual T& operator[](std::size_t) const = 0;
+
IterType begin()
{
return IterType(M, position, 0);
@@ -81,32 +91,6 @@ public:
virtual IterType end() = 0;
- /*
- * Scalar product
- */
-
- template<std::size_t P>
- T operator*(const mm::iter::vector_iterator<T, Rows, P, IterType, Grid>& v)
- {
- T out(0);
-
- for (unsigned k(0); k < Rows; ++k)
- out += (*this)[k] * v[k];
-
- return out;
- }
-
- template<std::size_t P>
- T operator*(const mm::iter::vector_iterator<T, P, Cols, IterType, Grid>& v)
- {
- T out(0);
-
- for (unsigned k(0); k < Cols; ++k)
- out += (*this)[k] * v[k];
-
- return out;
- }
-
protected:
Grid& M; // grid mapping
@@ -118,6 +102,28 @@ protected:
virtual IterType cpy() = 0;
};
+
+/*
+ * Scalar product
+ */
+
+template<typename T,
+ std::size_t R1, std::size_t C1,
+ std::size_t R2, std::size_t C2,
+ class IterType1, class IterType2,
+ class Grid1, class Grid2>
+typename std::remove_const<T>::type operator*(const mm::iter::vector_iterator<T, R1, C1, IterType1, Grid1>& v,
+ const mm::iter::vector_iterator<T, R2, C2, IterType2, Grid2>& w)
+{
+ typename std::remove_const<T>::type out(0);
+ const std::size_t N = std::min(v.size(), w.size());
+
+ for(unsigned i = 0; i < N; ++i)
+ out += v[i] * w[i];
+
+ return out;
+}
+
template<typename T, std::size_t Rows, std::size_t Cols, class Grid>
class mm::iter::basic_iterator : public mm::iter::vector_iterator<T, Rows, Cols, mm::iter::basic_iterator<T, Rows, Cols, Grid>, Grid>
{
@@ -147,11 +153,12 @@ public:
assert(pos < Cols);
}
- virtual bool ok() const override
+ virtual std::size_t size() const
{
- return (direction) ? this->index < Cols : this->index < Rows;
+ return (direction) ? Cols : Rows;
}
+
virtual T& operator*() override
{
return (direction) ?
@@ -167,6 +174,13 @@ public:
this->M.data[i * Cols + this->position];
}
+ virtual T& operator[](std::size_t i) const override
+ {
+ return (direction) ?
+ this->M.data[this->position * Cols + i] :
+ this->M.data[i * Cols + this->position];
+ }
+
virtual mm::iter::basic_iterator<T, Rows, Cols, Grid> end()
{
return mm::iter::basic_iterator<T, Rows, Cols, Grid>(this->M, this->position,
@@ -198,9 +212,9 @@ public:
assert(this->position < N);
}
- virtual bool ok() const
+ virtual std::size_t size() const
{
- return this->index < N;
+ return N - this->position;
}
virtual T& operator*() override
@@ -217,6 +231,14 @@ public:
this->M.data[i * N + (i + this->position)];
}
+ virtual T& operator[](std::size_t i) const override
+ {
+ return (sign) ?
+ this->M.data[(i - this->position) * N + i] :
+ this->M.data[i * N + (i + this->position)];
+ }
+
+
virtual mm::iter::diag_iterator<T, N, Grid> end()
{
return mm::iter::diag_iterator<T, N, Grid>(this->M, this->position, N);
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp
index 5eb0171..0b7a40b 100644
--- a/include/mm/mmmatrix.hpp
+++ b/include/mm/mmmatrix.hpp
@@ -414,7 +414,9 @@ mm::matrix<T, M, N> operator*(
assert(a.cols() == b.rows());
mm::matrix<T, M, N> result;
- mm::matrix<T, P2, N> bt = b.t(); // weak transposition
+ const mm::matrix<T, P2, N> bt = b.t(); // weak transposition
+
+ npdebug("Calling *")
for (unsigned row = 0; row < M; row++)
for (unsigned col = 0; col < N; col++)