diff options
-rw-r--r-- | include/mmmatrix.hpp | 59 | ||||
-rw-r--r-- | test/matrix_example.cpp | 4 |
2 files changed, 35 insertions, 28 deletions
diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp index 8f60312..6a31610 100644 --- a/include/mmmatrix.hpp +++ b/include/mmmatrix.hpp @@ -12,6 +12,7 @@ #pragma once #include <iostream> +#include <iomanip> #include <cstring> #include <cassert> #include <initializer_list> @@ -42,6 +43,9 @@ class mm::basic_matrix { public: using type = T; + template<typename U, std::size_t ORows, std::size_t OCols> + friend class mm::basic_matrix; + static constexpr std::size_t rows = Rows; static constexpr std::size_t cols = Cols; @@ -58,20 +62,17 @@ public: template<std::size_t ORows, std::size_t OCols> basic_matrix(const basic_matrix<T, ORows, OCols>& other); - // copy or move from 2D array - basic_matrix(const T (& values)[Rows][Cols]); - basic_matrix(T (&& values)[Rows][Cols]); - // access data T& at(std::size_t row, std::size_t col); const T& at(std::size_t row, std::size_t col) const; - auto&& operator[](std::size_t index); + // allows to access a matrix M at row j col k with M[j][k] + auto operator[](std::size_t index); void swap_rows(std::size_t x, std::size_t y); void swap_cols(std::size_t x, std::size_t y); // mathematical operations - basic_matrix<T, Cols, Rows> transposed() const; + virtual basic_matrix<T, Cols, Rows> transposed() const; inline basic_matrix<T, Cols, Rows> trd() const { return transposed(); } // bool is_invertible() const; @@ -79,7 +80,7 @@ public: /// downcast to square matrix - inline constexpr bool is_square() const { return (Rows == Cols); } + static inline constexpr bool is_square() { return (Rows == Cols); } inline constexpr square_matrix<T, Rows> to_square() const { static_assert(is_square()); return static_cast<square_matrix<T, Rows>>(*this); @@ -87,19 +88,23 @@ public: /// downcast to row_vector - inline constexpr bool is_row_vec() const { return (Cols == 1); } + static inline constexpr bool is_row_vec() { return (Cols == 1); } inline constexpr row_vec<T, Rows> to_row_vec() const { static_assert(is_row_vec()); return static_cast<row_vec<T, Rows>>(*this); } /// downcast to col_vector - inline constexpr bool is_col_vec() const { return (Rows == 1); } + static inline constexpr bool is_col_vec() { return (Rows == 1); } inline constexpr col_vec<T, Cols> to_col_vec() const { static_assert(is_col_vec()); return static_cast<col_vec<T, Cols>>(*this); } +protected: + template<typename Iterator> + basic_matrix(Iterator begin, Iterator end); + private: std::array<T, Rows * Cols> data; }; @@ -150,13 +155,13 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix( this->at(row, col) = other.at(row, col); } +/* protected construtor */ template<typename T, std::size_t Rows, std::size_t Cols> -mm::basic_matrix<T, Rows, Cols>::basic_matrix(const T (& values)[Rows][Cols]) - : data(values) {} - -template<typename T, std::size_t Rows, std::size_t Cols> -mm::basic_matrix<T, Rows, Cols>::basic_matrix(T (&& values)[Rows][Cols]) - : data(std::forward<decltype(values)>(values)) {} +template<typename Iterator> +mm::basic_matrix<T, Rows, Cols>::basic_matrix(Iterator begin, Iterator end) { + assert(static_cast<unsigned>(std::distance(begin, end)) >= ((Rows * Cols))); + std::copy(begin, end, data.begin()); +} /* member functions */ @@ -178,17 +183,15 @@ const T& mm::basic_matrix<T, Rows, Cols>::at(std::size_t row, std::size_t col) c } template<typename T, std::size_t Rows, std::size_t Cols> -auto&& mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) { - if constexpr (is_row_vec()) { - static_assert(index < Rows); - return data[index]; - } else if constexpr (is_col_vec()) { - static_assert(index < Cols); - return data[index]; +auto mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) { + if constexpr (is_row_vec() || is_col_vec()) { + return data.at(index); + } else { + return row_vec<T, Rows>( + data.begin() + (index * Cols), + data.begin() + ((index + 1) * Cols) + 1 + ); } - - // TODO: fix - // return row_vec<T, Rows>(std::move(data[index])); } template<typename T, std::size_t Rows, std::size_t Cols> @@ -283,14 +286,14 @@ mm::basic_matrix<T, Rows, Cols> operator-( return a + (static_cast<T>(-1) * b); } -template<typename T, std::size_t Rows, std::size_t Cols> +template<typename T, std::size_t Rows, std::size_t Cols, unsigned NumW = 3> std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols>& m) { for (unsigned row = 0; row < Rows; row++) { os << "[ "; for (unsigned col = 0; col < (Cols -1); col++) { - os << m.at(row, col) << ", "; + os << std::setw(NumW) << m.at(row, col) << ", "; } - os << m.at(row, (Cols -1)) << " ]\n"; + os << std::setw(NumW) << m.at(row, (Cols -1)) << " ]\n"; } return os; diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp index 4cf1863..26aeede 100644 --- a/test/matrix_example.cpp +++ b/test/matrix_example.cpp @@ -12,6 +12,10 @@ int main(int argc, char *argv[]) { std::cout << "a = \n" << a; std::cout << "b = \n" << b; std::cout << "c = \n" << c; + + // access elements + std::cout << "a.at(2,0) = " << a.at(2, 0) << std::endl; + std::cout << "a[2][0] = " << a[2][0] << std::endl;; // basic operations std::cout << "a + b = \n" << a + b; |