diff options
Diffstat (limited to 'include/mm/mmmatrix.hpp')
-rw-r--r-- | include/mm/mmmatrix.hpp | 330 |
1 files changed, 198 insertions, 132 deletions
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index 7834a3e..9910273 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -23,8 +23,32 @@ namespace mm { template<typename T, std::size_t Rows, std::size_t Cols> class basic_matrix; + /* specialization of basic_matrix for Rows == Cols */ + template<typename T, std::size_t N> + class square_interface; + + /* + * Most general type of matrix iterator + * IterType = Row, Column, Diagonal + * Grid = constness of mm::basic_matrix + */ + template<typename T, std::size_t Rows, std::size_t Cols, int IterType, class Grid> + class vector_iterator; + + /* + * Access methods interface + */ + + template<typename T, std::size_t Rows, std::size_t Cols, bool Regular> + class access; + + /* specialisations */ + + template<typename T, std::size_t Rows, std::size_t Cols> + class matrix; // simple matrix format + template<typename T, std::size_t Rows, std::size_t Cols> - class transposed_matrix; + class t_matrix; // transposed matrix format /* specialization of basic_matrx for Cols = 1 */ template<typename T, std::size_t Rows> @@ -32,50 +56,41 @@ namespace mm { /* specialization of basic_matrx for Rows = 1 */ template<typename T, std::size_t Cols> - class col_vec; + class col_vec; // transposed version of row_vec - /* shorter name for basic_matrix */ - template<typename T, std::size_t Rows, std::size_t Cols> - class matrix; - - /* specialization of basic_matrix for Rows == Cols */ template<typename T, std::size_t N> class square_matrix; - template<typename T, std::size_t N, int K = 0> - class diagonal_matrix; + template<typename T, std::size_t N> + class t_square_matrix; - /* - * Iterators - */ + /* specialisation of a square_matrix for a sub-diagonal composed matrix */ + template<typename T, std::size_t N, std::size_t K = 0> + class diag_matrix; - /* - * Most abstract type of iterator - * IterType = Row, Col, Diag - * Grid = constness of mm::basic_matrix - */ - template<typename T, std::size_t Rows, std::size_t Cols, int IterType, template <typename, std::size_t, std::size_t> class Grid> - class vector_iterator; + template<typename T, std::size_t N, std::size_t K = 0> + class t_diag_matrix; } +// TODO, short term solution #define MM_ROW_ITER 0 #define MM_COL_ITER 1 #define MM_DIAG_ITER 2 -template<typename T, std::size_t Rows, std::size_t Cols, int IterType, template <typename, std::size_t, std::size_t> class Grid> +template<typename T, std::size_t Rows, std::size_t Cols, int IterType, class Grid> class mm::vector_iterator { std::size_t index; // variable index - Grid<T, Rows, Cols>& M; + Grid& M; const int position; // fixed index, negative too for diagonal iterator public: - template<typename U, std::size_t ORows, std::size_t OCols, class OIterType, template <typename, std::size_t, std::size_t> class OGrid> + template<typename U, std::size_t ORows, std::size_t OCols, class OIterType, class OGrid> friend class vector_iterator; - vector_iterator(Grid<T, Rows, Cols>& M, int position, std::size_t index = 0); + vector_iterator(Grid& M, int position, std::size_t index = 0); mm::vector_iterator<T, Rows, Cols, IterType, Grid> operator++() { @@ -130,27 +145,75 @@ public: namespace mm { template<typename T, std::size_t Rows, std::size_t Cols> - using row_iterator = vector_iterator<T, Rows, Cols, MM_ROW_ITER, mm::basic_matrix>; + using row_iterator = vector_iterator<T, Rows, Cols, MM_ROW_ITER, mm::basic_matrix<T, Rows, Cols>>; template<typename T, std::size_t Rows, std::size_t Cols> - using col_iterator = vector_iterator<T, Rows, Cols, MM_COL_ITER, mm::basic_matrix>; + using col_iterator = vector_iterator<T, Rows, Cols, MM_COL_ITER, mm::basic_matrix<T, Rows, Cols>>; template<typename T, std::size_t Rows, std::size_t Cols> - using const_row_iterator = vector_iterator<T, Rows, Cols, MM_ROW_ITER, const mm::basic_matrix>; + using const_row_iterator = vector_iterator<T, Rows, Cols, MM_ROW_ITER, std::add_const<mm::basic_matrix<T, Rows, Cols>>>; template<typename T, std::size_t Rows, std::size_t Cols> - using const_col_iterator = vector_iterator<T, Rows, Cols, MM_COL_ITER, const mm::basic_matrix>; + using const_col_iterator = vector_iterator<T, Rows, Cols, MM_COL_ITER, std::add_const<mm::basic_matrix<T, Rows, Cols>>>; template<typename T, std::size_t N> - using diag_iterator = vector_iterator<T, N, N, MM_DIAG_ITER, mm::basic_matrix>; + using diag_iterator = vector_iterator<T, N, N, MM_DIAG_ITER, mm::basic_matrix<T, N, N>>; template<typename T, std::size_t N> - using const_diag_iterator = vector_iterator<T, N, N, MM_DIAG_ITER, const mm::basic_matrix>; + using const_diag_iterator = vector_iterator<T, N, N, MM_DIAG_ITER, std::add_const<mm::basic_matrix<T, N, N>>>; } +/* + * Accessors + */ +template<typename T, std::size_t Rows, std::size_t Cols, bool Regular> +class mm::access +{ +public: + + //access(mm::basic_matrix<T, Rows, Cols>& ref) : M(ref) {} + + T& at(std::size_t row, std::size_t col); + const T& at(std::size_t row, std::size_t col) const; + + auto operator[](std::size_t index); + auto operator[](std::size_t index) const; + +//private: +// mm::basic_matrix<T, Rows, Cols>& M; +protected: + std::array<T, Rows * Cols> data; +}; /* - * Matrix class + * Square interface + */ + +template<typename T, std::size_t N> +class mm::square_interface { +public: + + //square_interface(mm:basic_matrix<T, N, N>& _M) : M(_M) {} + + T trace(); + inline T tr() { return trace(); } + + // TODO, determinant + + /// in place inverse + // TODO, det != 0 + // TODO, use gauss jordan for invertible ones + //void invert();, TODO, section algorithm + +//private: +// mm:basic_matrix<T, N, N>& M; // one information more than mm::matrix ! +protected: + std::array<T, N * N> data; +}; + + +/* + * Matrix class, no access methods */ template<typename T, std::size_t Rows, std::size_t Cols> @@ -161,9 +224,6 @@ public: template<typename U, std::size_t ORows, std::size_t OCols> friend class mm::basic_matrix; - template<typename U, std::size_t ORows, std::size_t OCols> - friend class mm::vector_iterator; - static constexpr std::size_t rows = Rows; static constexpr std::size_t cols = Cols; @@ -181,26 +241,19 @@ public: basic_matrix(const basic_matrix<T, ORows, OCols>& other); // access data, basic definition - virtual T& at(std::size_t row, std::size_t col); - virtual const T& at(std::size_t row, std::size_t col) const; + //virtual T& at(std::size_t row, std::size_t col); + //virtual const T& at(std::size_t row, std::size_t col) const; // allows to access a matrix M at row j col k with M[j][k] - auto operator[](std::size_t index); - - virtual auto row_begin(std::size_t index) - { - return mm::row_iterator<T, Rows, Cols>(*this, static_cast<int>(index)); - } + //auto operator[](std::size_t index); void swap_rows(std::size_t x, std::size_t y); void swap_cols(std::size_t x, std::size_t y); // mathematical operations - // TODO, simply switch iteration mode //virtual basic_matrix<T, Cols, Rows> transposed() const; //inline basic_matrix<T, Cols, Rows> td() const { return transposed(); } - /// downcast to square matrix static inline constexpr bool is_square() { return (Rows == Cols); } inline constexpr square_matrix<T, Rows> to_square() const { @@ -208,7 +261,6 @@ public: return static_cast<square_matrix<T, Rows>>(*this); } - /// downcast to row_vector static inline constexpr bool is_row_vec() { return (Cols == 1); } inline constexpr row_vec<T, Rows> to_row_vec() const { @@ -227,7 +279,7 @@ protected: template<typename ConstIterator> basic_matrix(ConstIterator begin, ConstIterator end); -private: +//private: std::array<T, Rows * Cols> data; }; @@ -290,7 +342,7 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix( /* member functions */ -template<typename T, std::size_t Rows, std::size_t Cols> +/*template<typename T, std::size_t Rows, std::size_t Cols> T& mm::basic_matrix<T, Rows, Cols>::at(std::size_t row, std::size_t col) { assert(row < Rows); // "out of row bound" assert(col < Cols); // "out of column bound" @@ -311,14 +363,14 @@ auto mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) { if constexpr (is_row_vec() || is_col_vec()) { return data.at(index); } else { - return row_begin(index); + return mm::row_iterator<T, Rows, Cols>(*this, static_cast<int>(index)); - /*return row_vec<T, Rows>( - data.cbegin() + (index * Cols), - data.cbegin() + ((index + 1) * Cols) + 1 - );*/ + //return row_vec<T, Rows>( + // data.cbegin() + (index * Cols), + // data.cbegin() + ((index + 1) * Cols) + 1 + ); } -} +}*/ template<typename T, std::size_t Rows, std::size_t Cols> @@ -351,8 +403,8 @@ mm::basic_matrix<T, N, M> mm::basic_matrix<T, M, N>::transposed() const { }*/ -/* operator overloading */ -template<typename T, std::size_t Rows, std::size_t Cols> +/* TODO, operator overloading */ +/*template<typename T, std::size_t Rows, std::size_t Cols> mm::basic_matrix<T, Rows, Cols> operator+( const mm::basic_matrix<T, Rows, Cols>& a, const mm::basic_matrix<T, Rows, Cols>& b @@ -424,93 +476,55 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols> } return os; -} - - +}*/ /* * derivated classes */ -/* row vector specialization */ -template<typename T, std::size_t Rows> -class mm::row_vec : public mm::basic_matrix<T, Rows, 1> { -public: - using mm::basic_matrix<T, Rows, 1>::basic_matrix; -}; - -/* column vector specialization */ -template<typename T, std::size_t Cols> -class mm::col_vec : public mm::basic_matrix<T, 1, Cols> { -public: - using mm::basic_matrix<T, 1, Cols>::basic_matrix; -}; - -/* general specialization (alias) */ +// simple format matrix template<typename T, std::size_t Rows, std::size_t Cols> -class mm::matrix : public mm::basic_matrix<T, Rows, Cols> { +class mm::matrix : public mm::basic_matrix<T, Rows, Cols>, virtual public mm::access<T, Rows, Cols, true> +{ public: using mm::basic_matrix<T, Rows, Cols>::basic_matrix; }; -/* - * transposed matrix format - */ - +// transposed matrix template<typename T, std::size_t Rows, std::size_t Cols> -class mm::transposed_matrix : public mm::basic_matrix<T, Rows, Cols> +class mm::t_matrix : public mm::basic_matrix<T, Rows, Cols>, virtual public mm::access<T, Rows, Cols, false> { public: using mm::basic_matrix<T, Rows, Cols>::basic_matrix; +}; - virtual T& at(std::size_t row, std::size_t col) override - { - return mm::basic_matrix<T, Rows, Cols>::at(col, row); - } +/* row vector specialization */ +template<typename T, std::size_t Rows> +class mm::row_vec : public mm::matrix<T, Rows, 1> { +public: + using mm::matrix<T, Rows, 1>::matrix; - virtual const T& at(std::size_t row, std::size_t col) const override - { - return mm::basic_matrix<T, Rows, Cols>::at(col, row); - } + // TODO, begin, end +}; - // allows to access a matrix M at row j col k with M[j][k] - virtual auto row_begin(std::size_t index) override - { - return mm::col_iterator<T, Rows, Cols>(*this, static_cast<int>(index)); - } +/* column vector specialization */ +template<typename T, std::size_t Cols> +class mm::col_vec : public mm::t_matrix<T, 1, Cols> { +public: + using mm::t_matrix<T, 1, Cols>::t_matrix; + + // TODO, begin, end }; /* square matrix specialization */ template<typename T, std::size_t N> -class mm::square_matrix : public mm::basic_matrix<T, N, N> { +class mm::square_matrix : public mm::matrix<T, N, N> , virtual public mm::square_interface<T, N> { public: - using mm::basic_matrix<T, N, N>::basic_matrix; - - /// in place transpose - //void transpose(); - //inline void t() { transpose(); } - - T trace(); - inline T tr() { return trace(); } - - /// in place inverse - // TODO, det != 0 - // TODO, use gauss jordan for invertible ones - void invert(); - - - // TODO, downcast to K-diagonal, user defined cast - /*template<int K> - operator mm::diagonal_matrix<T, N, K>() const - { - // it's always possible to do it bidirectionally, - // without loosing information - return reinterpret_cast<mm::diagonal_matrix<T, N, K>>(*this); - }*/ + using mm::matrix<T, N, N>::matrix; // get the identity of size N static inline constexpr square_matrix<T, N> identity() { - square_matrix<T, N> i; + mm::square_matrix<T, N> i; for (unsigned row = 0; row < N; row++) for (unsigned col = 0; col < N; col++) i.at(row, col) = (row == col) ? 1 : 0; @@ -519,12 +533,18 @@ public: } }; +template<typename T, std::size_t N> +class mm::t_square_matrix : virtual public mm::t_matrix<T, N, N>, virtual public mm::square_interface<T, N> { +public: + using mm::t_matrix<T, N, N>::t_matrix; +}; + /* * K-diagonal square matrix format * K is bounded between ]-N, N[ */ -template<typename T, std::size_t N, int K> +/*template<typename T, std::size_t N, std::size_t K> class mm::diagonal_matrix : public mm::square_matrix<T, N> { public: @@ -532,7 +552,7 @@ public: // TODO, redefine at, operator[] // TODO, matrix multiplication -}; +};*/ /*template<typename T, std::size_t N> void mm::square_matrix<T, N>::transpose() { @@ -541,20 +561,10 @@ void mm::square_matrix<T, N>::transpose() { std::swap(this->at(row, col), this->at(col, row)); }*/ -template<typename T, std::size_t N> -T mm::square_matrix<T, N>::trace() { - - T sum = 0; - for (mm::diag_iterator<T, N> it(*this, 0); it.ok(); ++it) - sum += *it; - - return sum; -} - /* Iterators implementation */ -template<typename T, std::size_t Rows, std::size_t Cols, int IterType, template <typename, std::size_t, std::size_t> class Grid> -mm::vector_iterator<T, Rows, Cols, IterType, Grid>::vector_iterator(Grid<T, Rows, Cols>& _M, int pos, std::size_t i) +template<typename T, std::size_t Rows, std::size_t Cols, int IterType, class Grid> +mm::vector_iterator<T, Rows, Cols, IterType, Grid>::vector_iterator(Grid& _M, int pos, std::size_t i) : index(i), M(_M), position(pos) { if constexpr (IterType == MM_ROW_ITER) { @@ -566,7 +576,7 @@ mm::vector_iterator<T, Rows, Cols, IterType, Grid>::vector_iterator(Grid<T, Rows } } -template<typename T, std::size_t Rows, std::size_t Cols, int IterType, template <typename, std::size_t, std::size_t> class Grid> +template<typename T, std::size_t Rows, std::size_t Cols, int IterType, class Grid> T& mm::vector_iterator<T, Rows, Cols, IterType, Grid>::operator*() const { if constexpr (IterType == MM_ROW_ITER) @@ -579,7 +589,7 @@ T& mm::vector_iterator<T, Rows, Cols, IterType, Grid>::operator*() const M.data[index * Cols + (index - position)]; } -template<typename T, std::size_t Rows, std::size_t Cols, int IterType, template <typename, std::size_t, std::size_t> class Grid> +template<typename T, std::size_t Rows, std::size_t Cols, int IterType, class Grid> T& mm::vector_iterator<T, Rows, Cols, IterType, Grid>::operator[](std::size_t i) { if constexpr (IterType == MM_ROW_ITER) @@ -592,3 +602,59 @@ T& mm::vector_iterator<T, Rows, Cols, IterType, Grid>::operator[](std::size_t i) M.data[i * Cols + (i - position)]; } +/* + * Accessors implementation + */ + +template<typename T, std::size_t Rows, std::size_t Cols, bool Regular> +T& mm::access<T, Rows, Cols, Regular>::at(std::size_t row, std::size_t col) +{ + if constexpr (Regular) + return data[row * Cols + col]; + else + return data[col * Cols + row]; // transpose +} + +template<typename T, std::size_t Rows, std::size_t Cols, bool Regular> +const T& mm::access<T, Rows, Cols, Regular>::at(std::size_t row, std::size_t col) const +{ + if constexpr (Regular) + return data[row * Cols + col]; + else + return data[col * Cols + row]; // transpose +} + +template<typename T, std::size_t Rows, std::size_t Cols, bool Regular> +auto mm::access<T, Rows, Cols, Regular>::operator[](std::size_t index) +{ + if constexpr (this->is_row_vec() || this->is_col_vec()) + return data.at(index); + else if (Regular) + return mm::row_iterator<T, Rows, Cols>(*this, static_cast<int>(index)); + else + return mm::col_iterator<T, Rows, Cols>(*this, static_cast<int>(index)); +} + +template<typename T, std::size_t Rows, std::size_t Cols, bool Regular> +auto mm::access<T, Rows, Cols, Regular>::operator[](std::size_t index) const +{ + if constexpr (this->is_row_vec() || this->is_col_vec()) + return data.at(index); + else if (Regular) + return mm::const_row_iterator<T, Rows, Cols>(*this, static_cast<int>(index)); + else + return mm::const_col_iterator<T, Rows, Cols>(*this, static_cast<int>(index)); +} + +/* Square interface implementation */ + +template<typename T, std::size_t N> +T mm::square_interface<T, N>::trace() +{ + T sum = 0; + for (mm::diag_iterator<T, N> it(*this, 0); it.ok(); ++it) + sum += *it; + + return sum; +} + |