From b22ab9a435dceb362cb9eab8ec195d908fd8d5e9 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 2 Mar 2019 11:22:17 +0100 Subject: Update matrix test, add square matrix trace and fix comments --- include/mmmatrix.hpp | 97 ++++++++++++++++++++++++++++++++----------------- test/matrix_example.cpp | 27 +++++++++++++- 2 files changed, 90 insertions(+), 34 deletions(-) diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp index 6a31610..5285429 100644 --- a/include/mmmatrix.hpp +++ b/include/mmmatrix.hpp @@ -22,20 +22,24 @@ namespace mm { template class basic_matrix; + /* specialization of basic_matrx for Cols = 1 */ + template + class row_vec; + + /* specialization of basic_matrx for Rows = 1 */ + template + class col_vec; + + /* shorter name for basic_matrix */ template class matrix; + /* specialization of basic_matrix for Rows == Cols */ template class square_matrix; // template // class diag_matrix; - - template - class row_vec; - - template - class col_vec; } template @@ -73,7 +77,7 @@ public: // mathematical operations virtual basic_matrix transposed() const; - inline basic_matrix trd() const { return transposed(); } + inline basic_matrix td() const { return transposed(); } // bool is_invertible() const; // basic_matrix inverse() const; @@ -102,8 +106,8 @@ public: } protected: - template - basic_matrix(Iterator begin, Iterator end); + template + basic_matrix(ConstIterator begin, ConstIterator end); private: std::array data; @@ -157,8 +161,10 @@ mm::basic_matrix::basic_matrix( /* protected construtor */ template -template -mm::basic_matrix::basic_matrix(Iterator begin, Iterator end) { +template +mm::basic_matrix::basic_matrix( + ConstIterator begin, ConstIterator end +) { assert(static_cast(std::distance(begin, end)) >= ((Rows * Cols))); std::copy(begin, end, data.begin()); } @@ -188,8 +194,8 @@ auto mm::basic_matrix::operator[](std::size_t index) { return data.at(index); } else { return row_vec( - data.begin() + (index * Cols), - data.begin() + ((index + 1) * Cols) + 1 + data.cbegin() + (index * Cols), + data.cbegin() + ((index + 1) * Cols) + 1 ); } } @@ -300,26 +306,6 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix } -/* square matrix specializaiton */ - -template -class mm::square_matrix : public mm::basic_matrix { -public: - /// in place transpose - void transpose(); - inline void tr() { transpose(); } - - /// in place inverse - void invert(); -}; - - -template -void mm::square_matrix::transpose() { - for (unsigned row = 0; row < N; row++) - for (unsigned col = 0; col < row; col++) - std::swap(this->at(row, col), this->at(col, row)); -} /* row vector specialization */ @@ -336,8 +322,53 @@ public: using mm::basic_matrix::basic_matrix; }; +/* general specialization (alias) */ template class mm::matrix : public mm::basic_matrix { public: using mm::basic_matrix::basic_matrix; }; + +/* square matrix specializaiton */ +template +class mm::square_matrix : public mm::basic_matrix { +public: + using mm::basic_matrix::basic_matrix; + + /// in place transpose + void transpose(); + inline void t() { transpose(); } + + T trace(); + inline T tr() { return trace(); } + + /// in place inverse + void invert(); + + // get the identity of size N + static inline constexpr square_matrix identity() { + square_matrix i; + for (unsigned row = 0; row < N; row++) + for (unsigned col = 0; col < N; col++) + i.at(row, col) = (row == col) ? 1 : 0; + + return i; + } +}; + + +template +void mm::square_matrix::transpose() { + for (unsigned row = 0; row < N; row++) + for (unsigned col = 0; col < row; col++) + std::swap(this->at(row, col), this->at(col, row)); +} + +template +T mm::square_matrix::trace() { + T sum = 0; + for (unsigned i = 0; i < N; i++) + sum += this->at(i, i); + + return sum; +} diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp index 26aeede..469cbff 100644 --- a/test/matrix_example.cpp +++ b/test/matrix_example.cpp @@ -12,18 +12,43 @@ int main(int argc, char *argv[]) { std::cout << "a = \n" << a; std::cout << "b = \n" << b; std::cout << "c = \n" << c; + std::cout << std::endl; // access elements + std::cout << "Access elements" << std::endl; std::cout << "a.at(2,0) = " << a.at(2, 0) << std::endl; std::cout << "a[2][0] = " << a[2][0] << std::endl;; + std::cout << std::endl; // basic operations + std::cout << "Basic operations" << std::endl; std::cout << "a + b = \n" << a + b; std::cout << "a - b = \n" << a - b; std::cout << "a * c = \n" << a * c; std::cout << "a * 2 = \n" << a * 2; std::cout << "2 * a = \n" << 2 * a; - std::cout << "tr(a) = \n" << a.trd(); + std::cout << "a.td() = \n" << a.td(); // or a.trasposed(); + std::cout << std::endl; + + // special matrices + mm::square_matrix, 2> f {{{2, 3}, {1, 4}}, {{6, 1}, {-3, 4}}}; + + std::cout << "Square matrix" << std::endl; + std::cout << "f = \n" << f; + + std::cout << "tr(f) = " << f.tr() /* or f.trace() */ << std::endl; + + f.t(); + std::cout << "after in place transpose f.t(), f = \n" << f; + std::cout << std::endl; + + + auto identity = mm::square_matrix::identity(); + + std::cout << "Identity matrix" << std::endl; + std::cout << "I = \n" << identity; + std::cout << std::endl; + return 0; } -- cgit v1.2.1