From 2dd7c3bc4a6a49539e9847ec56a69cbf023e7e9b Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Fri, 1 Mar 2019 17:00:55 +0100 Subject: Fix matrix operator[] to allow M[j][k] and operator<< formatting --- include/mmmatrix.hpp | 59 ++++++++++++++++++++++++++----------------------- test/matrix_example.cpp | 4 ++++ 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp index 8f60312..6a31610 100644 --- a/include/mmmatrix.hpp +++ b/include/mmmatrix.hpp @@ -12,6 +12,7 @@ #pragma once #include +#include #include #include #include @@ -42,6 +43,9 @@ class mm::basic_matrix { public: using type = T; + template + friend class mm::basic_matrix; + static constexpr std::size_t rows = Rows; static constexpr std::size_t cols = Cols; @@ -58,20 +62,17 @@ public: template basic_matrix(const basic_matrix& other); - // copy or move from 2D array - basic_matrix(const T (& values)[Rows][Cols]); - basic_matrix(T (&& values)[Rows][Cols]); - // access data T& at(std::size_t row, std::size_t col); const T& at(std::size_t row, std::size_t col) const; - auto&& operator[](std::size_t index); + // allows to access a matrix M at row j col k with M[j][k] + auto operator[](std::size_t index); void swap_rows(std::size_t x, std::size_t y); void swap_cols(std::size_t x, std::size_t y); // mathematical operations - basic_matrix transposed() const; + virtual basic_matrix transposed() const; inline basic_matrix trd() const { return transposed(); } // bool is_invertible() const; @@ -79,7 +80,7 @@ public: /// downcast to square matrix - inline constexpr bool is_square() const { return (Rows == Cols); } + static inline constexpr bool is_square() { return (Rows == Cols); } inline constexpr square_matrix to_square() const { static_assert(is_square()); return static_cast>(*this); @@ -87,19 +88,23 @@ public: /// downcast to row_vector - inline constexpr bool is_row_vec() const { return (Cols == 1); } + static inline constexpr bool is_row_vec() { return (Cols == 1); } inline constexpr row_vec to_row_vec() const { static_assert(is_row_vec()); return static_cast>(*this); } /// downcast to col_vector - inline constexpr bool is_col_vec() const { return (Rows == 1); } + static inline constexpr bool is_col_vec() { return (Rows == 1); } inline constexpr col_vec to_col_vec() const { static_assert(is_col_vec()); return static_cast>(*this); } +protected: + template + basic_matrix(Iterator begin, Iterator end); + private: std::array data; }; @@ -150,13 +155,13 @@ mm::basic_matrix::basic_matrix( this->at(row, col) = other.at(row, col); } +/* protected construtor */ template -mm::basic_matrix::basic_matrix(const T (& values)[Rows][Cols]) - : data(values) {} - -template -mm::basic_matrix::basic_matrix(T (&& values)[Rows][Cols]) - : data(std::forward(values)) {} +template +mm::basic_matrix::basic_matrix(Iterator begin, Iterator end) { + assert(static_cast(std::distance(begin, end)) >= ((Rows * Cols))); + std::copy(begin, end, data.begin()); +} /* member functions */ @@ -178,17 +183,15 @@ const T& mm::basic_matrix::at(std::size_t row, std::size_t col) c } template -auto&& mm::basic_matrix::operator[](std::size_t index) { - if constexpr (is_row_vec()) { - static_assert(index < Rows); - return data[index]; - } else if constexpr (is_col_vec()) { - static_assert(index < Cols); - return data[index]; +auto mm::basic_matrix::operator[](std::size_t index) { + if constexpr (is_row_vec() || is_col_vec()) { + return data.at(index); + } else { + return row_vec( + data.begin() + (index * Cols), + data.begin() + ((index + 1) * Cols) + 1 + ); } - - // TODO: fix - // return row_vec(std::move(data[index])); } template @@ -283,14 +286,14 @@ mm::basic_matrix operator-( return a + (static_cast(-1) * b); } -template +template std::ostream& operator<<(std::ostream& os, const mm::basic_matrix& m) { for (unsigned row = 0; row < Rows; row++) { os << "[ "; for (unsigned col = 0; col < (Cols -1); col++) { - os << m.at(row, col) << ", "; + os << std::setw(NumW) << m.at(row, col) << ", "; } - os << m.at(row, (Cols -1)) << " ]\n"; + os << std::setw(NumW) << m.at(row, (Cols -1)) << " ]\n"; } return os; diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp index 4cf1863..26aeede 100644 --- a/test/matrix_example.cpp +++ b/test/matrix_example.cpp @@ -12,6 +12,10 @@ int main(int argc, char *argv[]) { std::cout << "a = \n" << a; std::cout << "b = \n" << b; std::cout << "c = \n" << c; + + // access elements + std::cout << "a.at(2,0) = " << a.at(2, 0) << std::endl; + std::cout << "a[2][0] = " << a[2][0] << std::endl;; // basic operations std::cout << "a + b = \n" << a + b; -- cgit v1.2.1