From 015612d5903888dc73db7ecbb5983438d4627ecd Mon Sep 17 00:00:00 2001 From: ancarola Date: Mon, 1 Jul 2019 15:58:29 +0200 Subject: Small correction on basic multiplication --- include/mm/mmiterator.hpp | 84 ++++++++++++++++++++++++++++++----------------- include/mm/mmmatrix.hpp | 4 ++- 2 files changed, 56 insertions(+), 32 deletions(-) diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp index 7b3480e..ed406fe 100644 --- a/include/mm/mmiterator.hpp +++ b/include/mm/mmiterator.hpp @@ -28,10 +28,13 @@ public: vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0) : M(_M), position(pos), index(i) {} +#ifdef MM_IMPLICIT_CONVERSION_ITERATOR operator T&() { + npdebug("Calling +") return *(*this); } +#endif IterType operator++() { @@ -69,11 +72,18 @@ public: return index != other.index; } - virtual bool ok() const = 0; + bool ok() const + { + return index < size(); + } + + virtual std::size_t size() const = 0; virtual T& operator*() = 0; virtual T& operator[](std::size_t) = 0; + virtual T& operator[](std::size_t) const = 0; + IterType begin() { return IterType(M, position, 0); @@ -81,32 +91,6 @@ public: virtual IterType end() = 0; - /* - * Scalar product - */ - - template - T operator*(const mm::iter::vector_iterator& v) - { - T out(0); - - for (unsigned k(0); k < Rows; ++k) - out += (*this)[k] * v[k]; - - return out; - } - - template - T operator*(const mm::iter::vector_iterator& v) - { - T out(0); - - for (unsigned k(0); k < Cols; ++k) - out += (*this)[k] * v[k]; - - return out; - } - protected: Grid& M; // grid mapping @@ -118,6 +102,28 @@ protected: virtual IterType cpy() = 0; }; + +/* + * Scalar product + */ + +template +typename std::remove_const::type operator*(const mm::iter::vector_iterator& v, + const mm::iter::vector_iterator& w) +{ + typename std::remove_const::type out(0); + const std::size_t N = std::min(v.size(), w.size()); + + for(unsigned i = 0; i < N; ++i) + out += v[i] * w[i]; + + return out; +} + template class mm::iter::basic_iterator : public mm::iter::vector_iterator, Grid> { @@ -147,11 +153,12 @@ public: assert(pos < Cols); } - virtual bool ok() const override + virtual std::size_t size() const { - return (direction) ? this->index < Cols : this->index < Rows; + return (direction) ? Cols : Rows; } + virtual T& operator*() override { return (direction) ? @@ -167,6 +174,13 @@ public: this->M.data[i * Cols + this->position]; } + virtual T& operator[](std::size_t i) const override + { + return (direction) ? + this->M.data[this->position * Cols + i] : + this->M.data[i * Cols + this->position]; + } + virtual mm::iter::basic_iterator end() { return mm::iter::basic_iterator(this->M, this->position, @@ -198,9 +212,9 @@ public: assert(this->position < N); } - virtual bool ok() const + virtual std::size_t size() const { - return this->index < N; + return N - this->position; } virtual T& operator*() override @@ -217,6 +231,14 @@ public: this->M.data[i * N + (i + this->position)]; } + virtual T& operator[](std::size_t i) const override + { + return (sign) ? + this->M.data[(i - this->position) * N + i] : + this->M.data[i * N + (i + this->position)]; + } + + virtual mm::iter::diag_iterator end() { return mm::iter::diag_iterator(this->M, this->position, N); diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index 5eb0171..0b7a40b 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -414,7 +414,9 @@ mm::matrix operator*( assert(a.cols() == b.rows()); mm::matrix result; - mm::matrix bt = b.t(); // weak transposition + const mm::matrix bt = b.t(); // weak transposition + + npdebug("Calling *") for (unsigned row = 0; row < M; row++) for (unsigned col = 0; col < N; col++) -- cgit v1.2.1