diff options
-rw-r--r-- | include/mmmatrix.hpp | 97 | ||||
-rw-r--r-- | test/matrix_example.cpp | 27 |
2 files changed, 90 insertions, 34 deletions
diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp index 6a31610..5285429 100644 --- a/include/mmmatrix.hpp +++ b/include/mmmatrix.hpp @@ -22,20 +22,24 @@ namespace mm { template<typename T, std::size_t Rows, std::size_t Cols> class basic_matrix; + /* specialization of basic_matrx for Cols = 1 */ + template<typename T, std::size_t Rows> + class row_vec; + + /* specialization of basic_matrx for Rows = 1 */ + template<typename T, std::size_t Cols> + class col_vec; + + /* shorter name for basic_matrix */ template<typename T, std::size_t Rows, std::size_t Cols> class matrix; + /* specialization of basic_matrix for Rows == Cols */ template<typename T, std::size_t N> class square_matrix; // template<typename T, std::size_t N> // class diag_matrix; - - template<typename T, std::size_t Rows> - class row_vec; - - template<typename T, std::size_t Cols> - class col_vec; } template<typename T, std::size_t Rows, std::size_t Cols> @@ -73,7 +77,7 @@ public: // mathematical operations virtual basic_matrix<T, Cols, Rows> transposed() const; - inline basic_matrix<T, Cols, Rows> trd() const { return transposed(); } + inline basic_matrix<T, Cols, Rows> td() const { return transposed(); } // bool is_invertible() const; // basic_matrix<T, Rows, Cols> inverse() const; @@ -102,8 +106,8 @@ public: } protected: - template<typename Iterator> - basic_matrix(Iterator begin, Iterator end); + template<typename ConstIterator> + basic_matrix(ConstIterator begin, ConstIterator end); private: std::array<T, Rows * Cols> data; @@ -157,8 +161,10 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix( /* protected construtor */ template<typename T, std::size_t Rows, std::size_t Cols> -template<typename Iterator> -mm::basic_matrix<T, Rows, Cols>::basic_matrix(Iterator begin, Iterator end) { +template<typename ConstIterator> +mm::basic_matrix<T, Rows, Cols>::basic_matrix( + ConstIterator begin, ConstIterator end +) { assert(static_cast<unsigned>(std::distance(begin, end)) >= ((Rows * Cols))); std::copy(begin, end, data.begin()); } @@ -188,8 +194,8 @@ auto mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) { return data.at(index); } else { return row_vec<T, Rows>( - data.begin() + (index * Cols), - data.begin() + ((index + 1) * Cols) + 1 + data.cbegin() + (index * Cols), + data.cbegin() + ((index + 1) * Cols) + 1 ); } } @@ -300,26 +306,6 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols> } -/* square matrix specializaiton */ - -template<typename T, std::size_t N> -class mm::square_matrix : public mm::basic_matrix<T, N, N> { -public: - /// in place transpose - void transpose(); - inline void tr() { transpose(); } - - /// in place inverse - void invert(); -}; - - -template<typename T, std::size_t N> -void mm::square_matrix<T, N>::transpose() { - for (unsigned row = 0; row < N; row++) - for (unsigned col = 0; col < row; col++) - std::swap(this->at(row, col), this->at(col, row)); -} /* row vector specialization */ @@ -336,8 +322,53 @@ public: using mm::basic_matrix<T, 1, Cols>::basic_matrix; }; +/* general specialization (alias) */ template<typename T, std::size_t Rows, std::size_t Cols> class mm::matrix : public mm::basic_matrix<T, Rows, Cols> { public: using mm::basic_matrix<T, Rows, Cols>::basic_matrix; }; + +/* square matrix specializaiton */ +template<typename T, std::size_t N> +class mm::square_matrix : public mm::basic_matrix<T, N, N> { +public: + using mm::basic_matrix<T, N, N>::basic_matrix; + + /// in place transpose + void transpose(); + inline void t() { transpose(); } + + T trace(); + inline T tr() { return trace(); } + + /// in place inverse + void invert(); + + // get the identity of size N + static inline constexpr square_matrix<T, N> identity() { + square_matrix<T, N> i; + for (unsigned row = 0; row < N; row++) + for (unsigned col = 0; col < N; col++) + i.at(row, col) = (row == col) ? 1 : 0; + + return i; + } +}; + + +template<typename T, std::size_t N> +void mm::square_matrix<T, N>::transpose() { + for (unsigned row = 0; row < N; row++) + for (unsigned col = 0; col < row; col++) + std::swap(this->at(row, col), this->at(col, row)); +} + +template<typename T, std::size_t N> +T mm::square_matrix<T, N>::trace() { + T sum = 0; + for (unsigned i = 0; i < N; i++) + sum += this->at(i, i); + + return sum; +} diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp index 26aeede..469cbff 100644 --- a/test/matrix_example.cpp +++ b/test/matrix_example.cpp @@ -12,18 +12,43 @@ int main(int argc, char *argv[]) { std::cout << "a = \n" << a; std::cout << "b = \n" << b; std::cout << "c = \n" << c; + std::cout << std::endl; // access elements + std::cout << "Access elements" << std::endl; std::cout << "a.at(2,0) = " << a.at(2, 0) << std::endl; std::cout << "a[2][0] = " << a[2][0] << std::endl;; + std::cout << std::endl; // basic operations + std::cout << "Basic operations" << std::endl; std::cout << "a + b = \n" << a + b; std::cout << "a - b = \n" << a - b; std::cout << "a * c = \n" << a * c; std::cout << "a * 2 = \n" << a * 2; std::cout << "2 * a = \n" << 2 * a; - std::cout << "tr(a) = \n" << a.trd(); + std::cout << "a.td() = \n" << a.td(); // or a.trasposed(); + std::cout << std::endl; + + // special matrices + mm::square_matrix<std::complex<int>, 2> f {{{2, 3}, {1, 4}}, {{6, 1}, {-3, 4}}}; + + std::cout << "Square matrix" << std::endl; + std::cout << "f = \n" << f; + + std::cout << "tr(f) = " << f.tr() /* or f.trace() */ << std::endl; + + f.t(); + std::cout << "after in place transpose f.t(), f = \n" << f; + std::cout << std::endl; + + + auto identity = mm::square_matrix<int, 3>::identity(); + + std::cout << "Identity matrix" << std::endl; + std::cout << "I = \n" << identity; + std::cout << std::endl; + return 0; } |