From 69bfc145590498104e9ed960c38c4e63b1c5c952 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 23 Feb 2019 00:58:50 +0100 Subject: Add +, -, * operators to mm::basic_matrix --- include/mmmatrix.hpp | 160 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 147 insertions(+), 13 deletions(-) diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp index 7934d87..51233cb 100644 --- a/include/mmmatrix.hpp +++ b/include/mmmatrix.hpp @@ -3,12 +3,8 @@ * * This library is not intended to be _performant_, it does not contain * hand written SMID / SSE / AVX optimizations. It is instead an example - * of highly abstracted code, where matrices can contain any data type. - * - * As a challenge, the matrix data structure has been built on a container - * of static capacity. But if a dynamic base container is needed, the code - * should be easily modifiable to add further abstraction, by templating - * the container and possibly the allocator. + * of highly inefficient (but abstract!) code, where matrices can contain any + * data type. * * Naoki Pross * 2018 ~ 2019 @@ -21,6 +17,12 @@ namespace mm { template class basic_matrix; + + template + using row_vec = basic_matrix; + + template + using col_vec = basic_matrix; } template @@ -31,22 +33,50 @@ public: static constexpr std::size_t rows = Rows; static constexpr std::size_t cols = Cols; - basic_matrix() {} + basic_matrix(const basic_matrix& other); + basic_matrix(basic_matrix&& other); + + template + basic_matrix(const basic_matrix& other); + + // access data + T& at(std::size_t row, std::size_t col); + + void swap_rows(std::size_t x, std::size_t y); + void swap_cols(std::size_t x, std::size_t y); + + // mathematical operations + basic_matrix transposed(); + inline basic_matrix t() { return transposed(); } - template - basic_matrix(const basic_matrix& other); + // bool is_invertible(); + // bool invert(); + // basic_matrix inverse(); - const T& at(std::size_t row, std::size_t col); + inline constexpr bool is_square() { + return (Rows == Cols); + } private: - std::array data; + T data[Rows][Cols] = {}; }; +template +mm::basic_matrix::basic_matrix(const mm::basic_matrix& other) { + for (int row = 0; row < Rows; row++) + for (int col = 0; col < Cols; col++) + data[row][col] = other.data[row][col]; +} + +template +mm::basic_matrix::basic_matrix(mm::basic_matrix&& other) { + data = other.data; +} template template -mm::basic_matrix::basic_matrix(const basic_matrix& other) { +mm::basic_matrix::basic_matrix(const mm::basic_matrix& other) { static_assert((ORows <= Rows), "cannot copy a taller matrix into a smaller one" ); @@ -55,5 +85,109 @@ mm::basic_matrix::basic_matrix(const basic_matrix +T& mm::basic_matrix::at(std::size_t row, std::size_t col) { + static_assert(row < Rows, "out of row bound"); + static_assert(col < Cols, "out of column bound"); + + return data[row][col]; +} + +template +void mm::basic_matrix::swap_rows(std::size_t x, std::size_t y) { + if (x == y) + return; + + for (int col = 0; col < Cols; col++) + std::swap(data[x][col], data[y][col]); +} + +template +void mm::basic_matrix::swap_cols(std::size_t x, std::size_t y) { + if (x == y) + return; + + for (int row = 0; row < rows; row++) + std::swap(data[row][x], data[row][y]); +} + +template +mm::basic_matrix mm::basic_matrix::transposed() { + mm::basic_matrix result; + + for (int row = 0; row < M; row++) + for (int col = 0; col < N; col++) + result.at(row, col) = at(col, row); + + return result; +} + + +/* operator overloading */ +template +mm::basic_matrix operator+( + const mm::basic_matrix& a, + const mm::basic_matrix& b +) { + mm::basic_matrix result; + + for (int row = 0; row < Rows; row++) + for (int col = 0; col < Cols; col++) + result.at(row, col) = a.at(row, col) + a.at(row, col); + + return result; +} + +template +mm::basic_matrix operator*( + const mm::basic_matrix& m, + const T& scalar +) { + mm::basic_matrix result; + for (int row = 0; row < Rows; row++) + for (int col = 0; col < Cols; col++) + result.at(row, col) = m.at(row, col) * scalar; + + return result; +} + +template +mm::basic_matrix operator*( + const T& scalar, + const mm::basic_matrix& m +) { + return m * scalar; +} + +template +mm::basic_matrix operator*( + const mm::basic_matrix& a, + const mm::basic_matrix& b +) { + mm::basic_matrix result; + + // TODO: use a more efficient algorithm + for (int row = 0; row < M; row++) + for (int col = 0; col < N; col++) + for (int k = 0; k < P; k++) + result.at(row, col) = a.at(row, k) * b.at(k, col); + + return result; +} + + +template +mm::basic_matrix operator-( + const mm::basic_matrix& a, + const mm::basic_matrix& b +) { + return a + static_cast(-1) * b; } -- cgit v1.2.1