From eb5cdb04efc9984d0937b65620606f8043dd1831 Mon Sep 17 00:00:00 2001 From: ancarola Date: Wed, 10 Jul 2019 10:24:37 +0200 Subject: Implicit convertion to basic_vec to vec2 or vec3 --- include/mm/mmmatrix.hpp | 198 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 142 insertions(+), 56 deletions(-) (limited to 'include/mm/mmmatrix.hpp') diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index 0b7a40b..facbb48 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -23,41 +23,30 @@ namespace mm { + /* basic grid structure */ + template class basic_matrix; - /* specialisations */ + /* basic wrapper */ template class matrix; // simple matrix format + + /* specialisations */ - /* specialization of basic_matrx for Cols = 1 */ - template - class row_vec; - - /* specialization of basic_matrx for Rows = 1 */ - template - class col_vec; // transposed version of row_vec + /* specialization of a matrix */ + template + class vector; // by default, set a column vector template class square_matrix; /* specialisation of a square_matrix for a sub-diagonal composed matrix */ - template - class diag_matrix; + template + class multi_diag_matrix; } - -/*namespace mm { - - template - using diag_iterator = vector_iterator>; - - template - using const_diag_iterator = vector_iterator::type, N, N, MM_DIAG_ITER, typename std::add_const>::type>; -}*/ - - /* * Matrix class, no access methods */ @@ -211,11 +200,6 @@ public: other.M = nullptr; } - // copy from another matrix - /*template - matrix(const matrix& other) - : M(std::make_shared(*other.M)), transposed(other.transposed) {} */ - matrix operator=(const basic_matrix& other) // deep copy { *M = *other.M; @@ -232,12 +216,9 @@ public: return *this; } - matrix transpose() const + const matrix transpose() const { - auto m = shallow_cpy(); - m.transposed = !transposed; - - return m; + return matrix(M, !transposed); } inline matrix& td() @@ -277,28 +258,24 @@ public: return static_cast>(*this); } - /// downcast to row_vector - static inline constexpr bool is_row_vec() { return (Cols == 1); } - inline constexpr row_vec to_row_vec() const { - static_assert(is_row_vec()); - return static_cast>(*this); - } - /// downcast to col_vector - static inline constexpr bool is_col_vec() { return (Rows == 1); } - inline constexpr col_vec to_col_vec() const { - static_assert(is_col_vec()); - return static_cast>(*this); + static inline constexpr bool is_vector() { return (Rows == 1 || Cols == 1); } + inline vector to_vector() const { + if constexpr(Cols == 1) + return static_cast>(*this); + else if (Rows == 1) + return vector(*this); // copy into column vector + } /* Accessors */ - T& at(std::size_t row, std::size_t col) + virtual T& at(std::size_t row, std::size_t col) { return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; } - const T& at(std::size_t row, std::size_t col) const + virtual const T& at(std::size_t row, std::size_t col) const { return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; } @@ -311,12 +288,12 @@ public: return (transposed) ? Rows : Cols; } - mm::matrix::vec_iterator operator[](std::size_t index) + virtual mm::matrix::vec_iterator operator[](std::size_t index) { return mm::matrix::vec_iterator(*M, index, 0, !transposed); } - mm::matrix::const_vec_iterator operator[](std::size_t index) const + virtual mm::matrix::const_vec_iterator operator[](std::size_t index) const { return mm::matrix::const_vec_iterator(*M, index, 0, !transposed); } @@ -358,12 +335,7 @@ protected: // shallow construction matrix(std::shared_ptr> grid, bool tr = false) : M(grid), transposed(tr) {} - - matrix shallow_cpy() const - { - return matrix(M, transposed); - } - + private: bool transposed; @@ -416,7 +388,7 @@ mm::matrix operator*( mm::matrix result; const mm::matrix bt = b.t(); // weak transposition - npdebug("Calling *") + //npdebug("Calling *") for (unsigned row = 0; row < M; row++) for (unsigned col = 0; col < N; col++) @@ -443,6 +415,33 @@ std::ostream& operator<<(std::ostream& os, const mm::matrix& m) { return os; } +/* + * Vector, TODO better manage column and row + */ + +template +class mm::vector : public mm::matrix +{ +public: + + using mm::matrix::matrix; + + vector(std::initializer_list l) + : mm::matrix(l) {} +}; + +template +mm::vector operator^(const mm::vector& v, const mm::vector& w) +{ + mm::vector out; + + out[0] = v[1] * w[2] - v[2] * w[2]; + out[1] = v[2] * w[0] - v[0] * w[2]; + out[2] = v[0] * w[1] - v[1] * w[0]; + + return out; +} + /* * Square matrix */ @@ -458,15 +457,15 @@ public: using const_diag_iterator = mm::iter::diag_iterator::type, N, typename std::add_const>::type>; - T trace(); + virtual T trace(); inline T tr() { return trace(); } - mm::square_matrix::diag_iterator diag_beg(int row = 0) + virtual mm::square_matrix::diag_iterator diag_beg(int row = 0) { return diag_iterator(*(this->M), row, 0); } - mm::square_matrix::const_diag_iterator diag_end(int row = 0) const + virtual mm::square_matrix::const_diag_iterator diag_end(int row = 0) const { return const_diag_iterator(*(this->M), row, N); } @@ -502,4 +501,91 @@ T mm::square_matrix::trace() return sum; } +// TODO, static assert, for all: Diags > -N, Diags < N +// TODO, force Diags to be ordered +template +class mm::multi_diag_matrix : public mm::square_matrix +{ + T& shared_zero = 0; + +public: + using mm::square_matrix::square_matrix; + + // TODO, ordered case: dichotomy search O(log(M)) + // M = parameter pack size + static inline bool constexpr is_in(std::size_t i, std::size_t j) + { + auto t = std::make_tuple(Diags...); + + for(unsigned k(0); k < sizeof...(Diags); ++k) + if ((i - j) == std::get(t)) + return true; + + return false; + } + + virtual T& at(std::size_t row, std::size_t col) override + { + if (is_in(row, col)) + return mm::square_matrix::at(row, col); + + shared_zero = 0; + return shared_zero; + } + + virtual const T& at(std::size_t row, std::size_t col) const override + { + if (is_in(row, col)) + return mm::square_matrix::at(row, col); + + return 0; + } + + + + // TODO, implement limited iterators +}; + +/*template +void constexpr diag_mult(const mm::multi_diag_matrix& a, + const mm::matrix& b, mm::matrix& result) +{ + static_assert(N == P && N == M, "invalid diagonal multiplication"); + + auto d = a.diagonal(Diags); + if constexpr (Diags < 0) { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = -Diags; i < N; ++i) + result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); + } else { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = Diags; i < N; ++i) + result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); + } +} + +template +mm::matrix operator*( + const mm::multi_diag_matrix& a, + const mm::matrix& b +) { + static_assert(N == P && N == M, "invalid matrix multiplication"); + assert(a.cols() == b.rows()); + + mm::matrix result; + (( + auto d = a.diagonal(Diags); + if constexpr (Diags < 0) { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = -Diags; i < N; ++i) + result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); + } else { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = Diags; i < N; ++i) + result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); + } + ) ...); + + return result; +}*/ -- cgit v1.2.1