diff options
Diffstat (limited to 'include/mm/mmmatrix.hpp')
-rw-r--r-- | include/mm/mmmatrix.hpp | 198 |
1 files changed, 142 insertions, 56 deletions
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index 0b7a40b..facbb48 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -23,41 +23,30 @@ namespace mm { + /* basic grid structure */ + template<typename T, std::size_t Rows, std::size_t Cols> class basic_matrix; - /* specialisations */ + /* basic wrapper */ template<typename T, std::size_t Rows, std::size_t Cols> class matrix; // simple matrix format + + /* specialisations */ - /* specialization of basic_matrx for Cols = 1 */ - template<typename T, std::size_t Rows> - class row_vec; - - /* specialization of basic_matrx for Rows = 1 */ - template<typename T, std::size_t Cols> - class col_vec; // transposed version of row_vec + /* specialization of a matrix */ + template<typename T, std::size_t N> + class vector; // by default, set a column vector template<typename T, std::size_t N> class square_matrix; /* specialisation of a square_matrix for a sub-diagonal composed matrix */ - template<typename T, std::size_t N, std::size_t K = 0> - class diag_matrix; + template<typename T, std::size_t N, signed long ... Diags> + class multi_diag_matrix; } - -/*namespace mm { - - template<typename T, std::size_t N> - using diag_iterator = vector_iterator<T, N, N, MM_DIAG_ITER, mm::basic_matrix<T, N, N>>; - - template<typename T, std::size_t N> - using const_diag_iterator = vector_iterator<typename std::add_const<T>::type, N, N, MM_DIAG_ITER, typename std::add_const<mm::basic_matrix<T, N, N>>::type>; -}*/ - - /* * Matrix class, no access methods */ @@ -211,11 +200,6 @@ public: other.M = nullptr; } - // copy from another matrix - /*template<std::size_t ORows, std::size_t OCols> - matrix(const matrix<T, ORows, OCols>& other) - : M(std::make_shared<mm::basic_matrix<T, Rows, Cols>(*other.M)), transposed(other.transposed) {} */ - matrix<T, Rows, Cols> operator=(const basic_matrix<T, Rows, Cols>& other) // deep copy { *M = *other.M; @@ -232,12 +216,9 @@ public: return *this; } - matrix<T, Rows, Cols> transpose() const + const matrix<T, Rows, Cols> transpose() const { - auto m = shallow_cpy(); - m.transposed = !transposed; - - return m; + return matrix<T, Rows, Cols>(M, !transposed); } inline matrix<T, Rows, Cols>& td() @@ -277,28 +258,24 @@ public: return static_cast<square_matrix<T, Rows>>(*this); } - /// downcast to row_vector - static inline constexpr bool is_row_vec() { return (Cols == 1); } - inline constexpr row_vec<T, Rows> to_row_vec() const { - static_assert(is_row_vec()); - return static_cast<row_vec<T, Rows>>(*this); - } - /// downcast to col_vector - static inline constexpr bool is_col_vec() { return (Rows == 1); } - inline constexpr col_vec<T, Cols> to_col_vec() const { - static_assert(is_col_vec()); - return static_cast<col_vec<T, Cols>>(*this); + static inline constexpr bool is_vector() { return (Rows == 1 || Cols == 1); } + inline vector<T, Cols> to_vector() const { + if constexpr(Cols == 1) + return static_cast<vector<T, Rows>>(*this); + else if (Rows == 1) + return vector<T, Cols>(*this); // copy into column vector + } /* Accessors */ - T& at(std::size_t row, std::size_t col) + virtual T& at(std::size_t row, std::size_t col) { return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; } - const T& at(std::size_t row, std::size_t col) const + virtual const T& at(std::size_t row, std::size_t col) const { return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; } @@ -311,12 +288,12 @@ public: return (transposed) ? Rows : Cols; } - mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index) + virtual mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index) { return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, 0, !transposed); } - mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const + virtual mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const { return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, 0, !transposed); } @@ -358,12 +335,7 @@ protected: // shallow construction matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid, bool tr = false) : M(grid), transposed(tr) {} - - matrix<T, Rows, Cols> shallow_cpy() const - { - return matrix<T, Rows, Cols>(M, transposed); - } - + private: bool transposed; @@ -416,7 +388,7 @@ mm::matrix<T, M, N> operator*( mm::matrix<T, M, N> result; const mm::matrix<T, P2, N> bt = b.t(); // weak transposition - npdebug("Calling *") + //npdebug("Calling *") for (unsigned row = 0; row < M; row++) for (unsigned col = 0; col < N; col++) @@ -444,6 +416,33 @@ std::ostream& operator<<(std::ostream& os, const mm::matrix<T, Rows, Cols>& m) { } /* + * Vector, TODO better manage column and row + */ + +template<typename T, std::size_t N> +class mm::vector : public mm::matrix<T, N, 1> +{ +public: + + using mm::matrix<T, N, 1>::matrix; + + vector(std::initializer_list<T> l) + : mm::matrix<T, N, 1>(l) {} +}; + +template<typename T> +mm::vector<T, 3> operator^(const mm::vector<T, 3>& v, const mm::vector<T, 3>& w) +{ + mm::vector<T, 3> out; + + out[0] = v[1] * w[2] - v[2] * w[2]; + out[1] = v[2] * w[0] - v[0] * w[2]; + out[2] = v[0] * w[1] - v[1] * w[0]; + + return out; +} + +/* * Square matrix */ @@ -458,15 +457,15 @@ public: using const_diag_iterator = mm::iter::diag_iterator<typename std::add_const<T>::type, N, typename std::add_const<mm::basic_matrix<T, N, N>>::type>; - T trace(); + virtual T trace(); inline T tr() { return trace(); } - mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0) + virtual mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0) { return diag_iterator(*(this->M), row, 0); } - mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const + virtual mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const { return const_diag_iterator(*(this->M), row, N); } @@ -502,4 +501,91 @@ T mm::square_matrix<T, N>::trace() return sum; } +// TODO, static assert, for all: Diags > -N, Diags < N +// TODO, force Diags to be ordered +template<typename T, std::size_t N, signed long ... Diags> +class mm::multi_diag_matrix : public mm::square_matrix<T, N> +{ + T& shared_zero = 0; + +public: + using mm::square_matrix<T, N>::square_matrix; + + // TODO, ordered case: dichotomy search O(log(M)) + // M = parameter pack size + static inline bool constexpr is_in(std::size_t i, std::size_t j) + { + auto t = std::make_tuple(Diags...); + + for(unsigned k(0); k < sizeof...(Diags); ++k) + if ((i - j) == std::get<k>(t)) + return true; + + return false; + } + + virtual T& at(std::size_t row, std::size_t col) override + { + if (is_in(row, col)) + return mm::square_matrix<T, N>::at(row, col); + + shared_zero = 0; + return shared_zero; + } + + virtual const T& at(std::size_t row, std::size_t col) const override + { + if (is_in(row, col)) + return mm::square_matrix<T, N>::at(row, col); + + return 0; + } + + + + // TODO, implement limited iterators +}; + +/*template<typename T, std::size_t N, signed long ... Diags, std::size_t P, std::size_t M> +void constexpr diag_mult(const mm::multi_diag_matrix<T, N, Diags...>& a, + const mm::matrix<T, P, M>& b, mm::matrix<T, M, N>& result) +{ + static_assert(N == P && N == M, "invalid diagonal multiplication"); + + auto d = a.diagonal(Diags); + if constexpr (Diags < 0) { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = -Diags; i < N; ++i) + result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); + } else { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = Diags; i < N; ++i) + result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); + } +} + +template<typename T, std::size_t N, signed long ... Diags, std::size_t P, std::size_t M> +mm::matrix<T, M, N> operator*( + const mm::multi_diag_matrix<T, N, Diags...>& a, + const mm::matrix<T, P, M>& b +) { + static_assert(N == P && N == M, "invalid matrix multiplication"); + assert(a.cols() == b.rows()); + + mm::matrix<T, M, N> result; + (( + auto d = a.diagonal(Diags); + if constexpr (Diags < 0) { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = -Diags; i < N; ++i) + result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); + } else { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = Diags; i < N; ++i) + result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); + } + ) ...); + + return result; +}*/ |