From a9828ad7c2b73527bcb4b28e49193e4523f21ad5 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 23 Feb 2019 16:42:19 +0100 Subject: Change storage for matrix to std::array, update matrix_example --- include/mmmatrix.hpp | 60 ++++++++++++++++++++++++------------------------- test/matrix_example.cpp | 16 +++++++++++-- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp index b973aba..8f60312 100644 --- a/include/mmmatrix.hpp +++ b/include/mmmatrix.hpp @@ -15,6 +15,7 @@ #include #include #include +#include namespace mm { template @@ -44,6 +45,8 @@ public: static constexpr std::size_t rows = Rows; static constexpr std::size_t cols = Cols; + basic_matrix(); + // from initializer_list basic_matrix(std::initializer_list> l); @@ -98,42 +101,35 @@ public: } private: - T data[Rows * Cols] = {}; + std::array data; }; +template +mm::basic_matrix::basic_matrix() { + std::fill(data.begin(), data.end(), 0); +} template mm::basic_matrix::basic_matrix( std::initializer_list> l ) { assert(l.size() == Rows); + auto data_it = data.begin(); - auto row_it = l.begin(); - for (unsigned row = 0; (row < Rows) && (row_it != l.end()); row++) { - assert((*row_it).size() == Cols); - - auto col_it = (*row_it).begin(); - for (unsigned col = 0; (col < Cols) && (col_it != (*row_it).end()); col++) { - this->at(row, col) = *col_it; - ++col_it; - } - ++row_it; + for (auto&& row : l) { + data_it = std::copy(row.begin(), row.end(), data_it); } } template mm::basic_matrix::basic_matrix( const mm::basic_matrix& other -) { - std::memcpy(&data, &other.data, sizeof(data)); -} +) : data(other.data) {} template mm::basic_matrix::basic_matrix( mm::basic_matrix&& other -) { - data = other.data; -} +) : data(std::forward(other.data)) {} template template @@ -148,18 +144,19 @@ mm::basic_matrix::basic_matrix( "cannot copy a larger matrix into a smaller one" ); - std::memcpy(&data, &other.data, sizeof(data)); + std::fill(data.begin(), data.end(), 0); + for (unsigned row = 0; row < Rows; row++) + for (unsigned col = 0; col < Cols; col++) + this->at(row, col) = other.at(row, col); } template -mm::basic_matrix::basic_matrix(const T (& values)[Rows][Cols]) { - std::memcpy(&data, &values, sizeof(data)); -} +mm::basic_matrix::basic_matrix(const T (& values)[Rows][Cols]) + : data(values) {} template -mm::basic_matrix::basic_matrix(T (&& values)[Rows][Cols]) { - data = values; -} +mm::basic_matrix::basic_matrix(T (&& values)[Rows][Cols]) + : data(std::forward(values)) {} /* member functions */ @@ -218,7 +215,7 @@ mm::basic_matrix mm::basic_matrix::transposed() const { for (unsigned row = 0; row < M; row++) for (unsigned col = 0; col < N; col++) - result.at(row, col) = this->at(row, col); + result.at(col, row) = this->at(row, col); return result; } @@ -234,7 +231,7 @@ mm::basic_matrix operator+( for (unsigned row = 0; row < Rows; row++) for (unsigned col = 0; col < Cols; col++) - result.at(row, col) = a.at(row, col) + a.at(row, col); + result.at(row, col) = a.at(row, col) + b.at(row, col); return result; } @@ -260,17 +257,18 @@ mm::basic_matrix operator*( return m * scalar; } -template +template mm::basic_matrix operator*( - const mm::basic_matrix& a, - const mm::basic_matrix& b + const mm::basic_matrix& a, + const mm::basic_matrix& b ) { + static_assert(P1 == P2, "invalid matrix multiplication"); mm::basic_matrix result; // TODO: use a more efficient algorithm for (unsigned row = 0; row < M; row++) for (unsigned col = 0; col < N; col++) - for (int k = 0; k < P; k++) + for (unsigned k = 0; k < P1; k++) result.at(row, col) = a.at(row, k) * b.at(k, col); return result; @@ -282,7 +280,7 @@ mm::basic_matrix operator-( const mm::basic_matrix& a, const mm::basic_matrix& b ) { - return a + static_cast(-1) * b; + return a + (static_cast(-1) * b); } template diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp index 3fb7c78..4cf1863 100644 --- a/test/matrix_example.cpp +++ b/test/matrix_example.cpp @@ -5,9 +5,21 @@ int main(int argc, char *argv[]) { std::cout << "MxN dimensional (int) matrices" << std::endl; - mm::matrix m {{1, 2}, {3, 4}, {5, 6}}; + mm::matrix a {{1, 2}, {3, 4}, {5, 6}}; + mm::matrix b {{4, 3}, {9, 1}, {2, 5}}; + mm::matrix c {{1, 2, 3, 4}, {5, 6, 7, 8}}; - std::cout << m; + std::cout << "a = \n" << a; + std::cout << "b = \n" << b; + std::cout << "c = \n" << c; + + // basic operations + std::cout << "a + b = \n" << a + b; + std::cout << "a - b = \n" << a - b; + std::cout << "a * c = \n" << a * c; + std::cout << "a * 2 = \n" << a * 2; + std::cout << "2 * a = \n" << 2 * a; + std::cout << "tr(a) = \n" << a.trd(); return 0; } -- cgit v1.2.1