diff options
Diffstat (limited to 'include/mm/mmmatrix.hpp')
-rw-r--r-- | include/mm/mmmatrix.hpp | 111 |
1 files changed, 73 insertions, 38 deletions
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index f88e90c..5eb0171 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -77,6 +77,9 @@ public: template<typename U, std::size_t ORows, std::size_t OCols, class Grid> friend class mm::iter::basic_iterator; + template<typename U, std::size_t ON, class Grid> + friend class mm::iter::diag_iterator; + //template<typename U, std::size_t ORows, std::size_t OCols, class Grid> //friend class mm::iter::basic_iterator<T, Rows, Cols, mm::basic_matrix<T, Rows, Cols>>; @@ -183,11 +186,6 @@ void mm::basic_matrix<T, Rows, Cols>::swap_cols(std::size_t x, std::size_t y) { template<typename T, std::size_t Rows, std::size_t Cols> class mm::matrix { -protected: - - // shallow construction - matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid = nullptr, bool tr = false) : M(grid), transposed(tr) {} - public: //template<typename U, std::size_t ORows, std::size_t OCols> @@ -228,7 +226,7 @@ public: * Transposition */ - matrix<T, Rows, Cols>& transpose() + matrix<T, Rows, Cols>& transpose_d() { transposed = !transposed; return *this; @@ -242,7 +240,7 @@ public: return m; } - inline matrix<T, Rows, Cols>& t() + inline matrix<T, Rows, Cols>& td() { return transpose(); } @@ -315,12 +313,12 @@ public: mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index) { - return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, !transposed); + return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, 0, !transposed); } mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const { - return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, !transposed); + return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, 0, !transposed); } /* @@ -358,7 +356,10 @@ protected: std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> M; - matrix<T, Rows, Cols> shallow_cpy() + // shallow construction + matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid, bool tr = false) : M(grid), transposed(tr) {} + + matrix<T, Rows, Cols> shallow_cpy() const { return matrix<T, Rows, Cols>(M, transposed); } @@ -408,6 +409,7 @@ mm::matrix<T, M, N> operator*( const mm::matrix<T, M, P1>& a, const mm::matrix<T, P2, N>& b ) { + // TODO, adjust asserts for transposed cases static_assert(P1 == P2, "invalid matrix multiplication"); assert(a.cols() == b.rows()); @@ -421,34 +423,6 @@ mm::matrix<T, M, N> operator*( return result; } -// transposed multiplication -/*template<typename T, std::size_t M, std::size_t P1, std::size_t P2, std::size_t N> -mm::matrix<T, M, N> operator*( - const mm::matrix<T, P1, M>& a, - const mm::matrix<T, P2, N>& b -) { - static_assert(P1 == P2, "invalid matrix multiplication"); - assert(a.cols() == b.rows()); - - mm::matrix<T, M, N> result; - mm::matrix<T, P2, N> bt = b.t(); // weak transposition - - for (unsigned row = 0; row < M; row++) - for (unsigned col = 0; col < N; col++) - result.at(row, col) = a[row] * bt[col]; // scalar product - - return result; -}*/ - - -/*template<typename T, std::size_t N> -void mm::square_matrix<T, N>::transpose() { - for (unsigned row = 0; row < N; row++) - for (unsigned col = 0; col < row; col++) - std::swap(this->at(row, col), this->at(col, row)); -}*/ - - /* * Matrix operator << */ @@ -466,3 +440,64 @@ std::ostream& operator<<(std::ostream& os, const mm::matrix<T, Rows, Cols>& m) { return os; } + +/* + * Square matrix + */ + +template<typename T, std::size_t N> +class mm::square_matrix : public mm::matrix<T, N, N> +{ +public: + + using mm::matrix<T, N, N>::matrix; + + using diag_iterator = mm::iter::diag_iterator<T, N, mm::basic_matrix<T, N, N>>; + + using const_diag_iterator = mm::iter::diag_iterator<typename std::add_const<T>::type, N, typename std::add_const<mm::basic_matrix<T, N, N>>::type>; + + T trace(); + inline T tr() { return trace(); } + + mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0) + { + return diag_iterator(*(this->M), row, 0); + } + + mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const + { + return const_diag_iterator(*(this->M), row, N); + } + + // TODO, determinant + + /// in place inverse + // TODO, det != 0 + // TODO, use gauss jordan for invertible ones + //void invert();, TODO, section algorithm + + /* + * Generate the identity + */ + + static inline constexpr mm::square_matrix<T, N> identity() { + mm::square_matrix<T, N> i; + for (unsigned row = 0; row < N; row++) + for (unsigned col = 0; col < N; col++) + i.at(row, col) = (row == col) ? 1 : 0; + + return i; + } +}; + +template<typename T, std::size_t N> +T mm::square_matrix<T, N>::trace() +{ + T sum = 0; + for (const auto& x : diag_beg()) + sum += x; + + return sum; +} + + |