diff options
author | ancarola <raffaele.ancarola@epfl.ch> | 2019-07-01 15:58:29 +0200 |
---|---|---|
committer | ancarola <raffaele.ancarola@epfl.ch> | 2019-07-01 15:58:29 +0200 |
commit | 015612d5903888dc73db7ecbb5983438d4627ecd (patch) | |
tree | da4105b20939324c35b2b9d7af6c564019cd4ca3 | |
parent | The matrix library is compiling and all tested operations work fine. (diff) | |
download | libmm-015612d5903888dc73db7ecbb5983438d4627ecd.tar.gz libmm-015612d5903888dc73db7ecbb5983438d4627ecd.zip |
Small correction on basic multiplication
-rw-r--r-- | include/mm/mmiterator.hpp | 84 | ||||
-rw-r--r-- | include/mm/mmmatrix.hpp | 4 |
2 files changed, 56 insertions, 32 deletions
diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp index 7b3480e..ed406fe 100644 --- a/include/mm/mmiterator.hpp +++ b/include/mm/mmiterator.hpp @@ -28,10 +28,13 @@ public: vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0) : M(_M), position(pos), index(i) {} +#ifdef MM_IMPLICIT_CONVERSION_ITERATOR operator T&() { + npdebug("Calling +") return *(*this); } +#endif IterType operator++() { @@ -69,11 +72,18 @@ public: return index != other.index; } - virtual bool ok() const = 0; + bool ok() const + { + return index < size(); + } + + virtual std::size_t size() const = 0; virtual T& operator*() = 0; virtual T& operator[](std::size_t) = 0; + virtual T& operator[](std::size_t) const = 0; + IterType begin() { return IterType(M, position, 0); @@ -81,32 +91,6 @@ public: virtual IterType end() = 0; - /* - * Scalar product - */ - - template<std::size_t P> - T operator*(const mm::iter::vector_iterator<T, Rows, P, IterType, Grid>& v) - { - T out(0); - - for (unsigned k(0); k < Rows; ++k) - out += (*this)[k] * v[k]; - - return out; - } - - template<std::size_t P> - T operator*(const mm::iter::vector_iterator<T, P, Cols, IterType, Grid>& v) - { - T out(0); - - for (unsigned k(0); k < Cols; ++k) - out += (*this)[k] * v[k]; - - return out; - } - protected: Grid& M; // grid mapping @@ -118,6 +102,28 @@ protected: virtual IterType cpy() = 0; }; + +/* + * Scalar product + */ + +template<typename T, + std::size_t R1, std::size_t C1, + std::size_t R2, std::size_t C2, + class IterType1, class IterType2, + class Grid1, class Grid2> +typename std::remove_const<T>::type operator*(const mm::iter::vector_iterator<T, R1, C1, IterType1, Grid1>& v, + const mm::iter::vector_iterator<T, R2, C2, IterType2, Grid2>& w) +{ + typename std::remove_const<T>::type out(0); + const std::size_t N = std::min(v.size(), w.size()); + + for(unsigned i = 0; i < N; ++i) + out += v[i] * w[i]; + + return out; +} + template<typename T, std::size_t Rows, std::size_t Cols, class Grid> class mm::iter::basic_iterator : public mm::iter::vector_iterator<T, Rows, Cols, mm::iter::basic_iterator<T, Rows, Cols, Grid>, Grid> { @@ -147,11 +153,12 @@ public: assert(pos < Cols); } - virtual bool ok() const override + virtual std::size_t size() const { - return (direction) ? this->index < Cols : this->index < Rows; + return (direction) ? Cols : Rows; } + virtual T& operator*() override { return (direction) ? @@ -167,6 +174,13 @@ public: this->M.data[i * Cols + this->position]; } + virtual T& operator[](std::size_t i) const override + { + return (direction) ? + this->M.data[this->position * Cols + i] : + this->M.data[i * Cols + this->position]; + } + virtual mm::iter::basic_iterator<T, Rows, Cols, Grid> end() { return mm::iter::basic_iterator<T, Rows, Cols, Grid>(this->M, this->position, @@ -198,9 +212,9 @@ public: assert(this->position < N); } - virtual bool ok() const + virtual std::size_t size() const { - return this->index < N; + return N - this->position; } virtual T& operator*() override @@ -217,6 +231,14 @@ public: this->M.data[i * N + (i + this->position)]; } + virtual T& operator[](std::size_t i) const override + { + return (sign) ? + this->M.data[(i - this->position) * N + i] : + this->M.data[i * N + (i + this->position)]; + } + + virtual mm::iter::diag_iterator<T, N, Grid> end() { return mm::iter::diag_iterator<T, N, Grid>(this->M, this->position, N); diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index 5eb0171..0b7a40b 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -414,7 +414,9 @@ mm::matrix<T, M, N> operator*( assert(a.cols() == b.rows()); mm::matrix<T, M, N> result; - mm::matrix<T, P2, N> bt = b.t(); // weak transposition + const mm::matrix<T, P2, N> bt = b.t(); // weak transposition + + npdebug("Calling *") for (unsigned row = 0; row < M; row++) for (unsigned col = 0; col < N; col++) |