From b22ab9a435dceb362cb9eab8ec195d908fd8d5e9 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 2 Mar 2019 11:22:17 +0100 Subject: Update matrix test, add square matrix trace and fix comments --- include/mmmatrix.hpp | 97 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 64 insertions(+), 33 deletions(-) (limited to 'include/mmmatrix.hpp') diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp index 6a31610..5285429 100644 --- a/include/mmmatrix.hpp +++ b/include/mmmatrix.hpp @@ -22,20 +22,24 @@ namespace mm { template class basic_matrix; + /* specialization of basic_matrx for Cols = 1 */ + template + class row_vec; + + /* specialization of basic_matrx for Rows = 1 */ + template + class col_vec; + + /* shorter name for basic_matrix */ template class matrix; + /* specialization of basic_matrix for Rows == Cols */ template class square_matrix; // template // class diag_matrix; - - template - class row_vec; - - template - class col_vec; } template @@ -73,7 +77,7 @@ public: // mathematical operations virtual basic_matrix transposed() const; - inline basic_matrix trd() const { return transposed(); } + inline basic_matrix td() const { return transposed(); } // bool is_invertible() const; // basic_matrix inverse() const; @@ -102,8 +106,8 @@ public: } protected: - template - basic_matrix(Iterator begin, Iterator end); + template + basic_matrix(ConstIterator begin, ConstIterator end); private: std::array data; @@ -157,8 +161,10 @@ mm::basic_matrix::basic_matrix( /* protected construtor */ template -template -mm::basic_matrix::basic_matrix(Iterator begin, Iterator end) { +template +mm::basic_matrix::basic_matrix( + ConstIterator begin, ConstIterator end +) { assert(static_cast(std::distance(begin, end)) >= ((Rows * Cols))); std::copy(begin, end, data.begin()); } @@ -188,8 +194,8 @@ auto mm::basic_matrix::operator[](std::size_t index) { return data.at(index); } else { return row_vec( - data.begin() + (index * Cols), - data.begin() + ((index + 1) * Cols) + 1 + data.cbegin() + (index * Cols), + data.cbegin() + ((index + 1) * Cols) + 1 ); } } @@ -300,26 +306,6 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix } -/* square matrix specializaiton */ - -template -class mm::square_matrix : public mm::basic_matrix { -public: - /// in place transpose - void transpose(); - inline void tr() { transpose(); } - - /// in place inverse - void invert(); -}; - - -template -void mm::square_matrix::transpose() { - for (unsigned row = 0; row < N; row++) - for (unsigned col = 0; col < row; col++) - std::swap(this->at(row, col), this->at(col, row)); -} /* row vector specialization */ @@ -336,8 +322,53 @@ public: using mm::basic_matrix::basic_matrix; }; +/* general specialization (alias) */ template class mm::matrix : public mm::basic_matrix { public: using mm::basic_matrix::basic_matrix; }; + +/* square matrix specializaiton */ +template +class mm::square_matrix : public mm::basic_matrix { +public: + using mm::basic_matrix::basic_matrix; + + /// in place transpose + void transpose(); + inline void t() { transpose(); } + + T trace(); + inline T tr() { return trace(); } + + /// in place inverse + void invert(); + + // get the identity of size N + static inline constexpr square_matrix identity() { + square_matrix i; + for (unsigned row = 0; row < N; row++) + for (unsigned col = 0; col < N; col++) + i.at(row, col) = (row == col) ? 1 : 0; + + return i; + } +}; + + +template +void mm::square_matrix::transpose() { + for (unsigned row = 0; row < N; row++) + for (unsigned col = 0; col < row; col++) + std::swap(this->at(row, col), this->at(col, row)); +} + +template +T mm::square_matrix::trace() { + T sum = 0; + for (unsigned i = 0; i < N; i++) + sum += this->at(i, i); + + return sum; +} -- cgit v1.2.1