diff options
-rw-r--r-- | include/mmmatrix.hpp | 101 |
1 files changed, 95 insertions, 6 deletions
diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp index 7cb9359..d18f3f1 100644 --- a/include/mmmatrix.hpp +++ b/include/mmmatrix.hpp @@ -12,16 +12,27 @@ #pragma once #include <iostream> +#include <cstring> namespace mm { template<typename T, std::size_t Rows, std::size_t Cols> class basic_matrix; + + template<typename T, std::size_t Rows, std::size_t Cols> + class matrix; + + template<typename T, std::size_t N> + class square_matrix; + + // template<typename T, std::size_t N> + // class diag_matrix; + template<typename T, std::size_t Rows> - using row_vec = basic_matrix<T, Rows, 1>; + class row_vec; template<typename T, std::size_t Cols> - using col_vec = basic_matrix<T, 1, Cols>; + class col_vec; } template<typename T, std::size_t Rows, std::size_t Cols> @@ -38,22 +49,45 @@ public: template<std::size_t ORows, std::size_t OCols> basic_matrix(const basic_matrix<T, ORows, OCols>& other); + basic_matrix(const T (& values)[Rows][Cols]); + basic_matrix(T (&& values)[Rows][Cols]); + // access data T& at(std::size_t row, std::size_t col); + auto&& operator[](std::size_t index); void swap_rows(std::size_t x, std::size_t y); void swap_cols(std::size_t x, std::size_t y); // mathematical operations basic_matrix<T, Cols, Rows> transposed(); - inline basic_matrix<T, Cols, Rows> t() { return transposed(); } + inline basic_matrix<T, Cols, Rows> trd() { return transposed(); } // bool is_invertible(); // bool invert(); // basic_matrix<T, Rows, Cols> inverse(); - inline constexpr bool is_square() { - return (Rows == Cols); + + /// downcast to square matrix + inline constexpr bool is_square() { return (Rows == Cols); } + inline constexpr square_matrix<T, Rows> to_square() { + static_assert(is_square()); + return static_cast<square_matrix<T, Rows>>(*this); + } + + + /// downcast to row_vector + inline constexpr bool is_row_vec() { return (Cols == 1); } + inline constexpr row_vec<T, Rows> to_row_vec() { + static_assert(is_row_vec()); + return static_cast<row_vec<T, Rows>>(*this); + } + + /// downcast to col_vector + inline constexpr bool is_col_vec() { return (Rows == 1); } + inline constexpr col_vec<T, Cols> to_col_vec() { + static_assert(is_col_vec()); + return static_cast<col_vec<T, Cols>>(*this); } private: @@ -89,6 +123,16 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix(const mm::basic_matrix<T, ORows, O data[row][col] = other.data[row][col]; } +template<typename T, std::size_t Rows, std::size_t Cols> +mm::basic_matrix<T, Rows, Cols>::basic_matrix(const T (& values)[Rows][Cols]) { + std::memcpy(&data, &values, sizeof(data)); +} + +template<typename T, std::size_t Rows, std::size_t Cols> +mm::basic_matrix<T, Rows, Cols>::basic_matrix(T (&& values)[Rows][Cols]) { + data = values; +} + /* member functions */ @@ -101,6 +145,16 @@ T& mm::basic_matrix<T, Rows, Cols>::at(std::size_t row, std::size_t col) { } template<typename T, std::size_t Rows, std::size_t Cols> +auto&& mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) { + if constexpr (is_row_vec()) + return data[0][index]; + else if constexpr (is_col_vec()) + return data[index][0]; + + return row_vec<T, Rows>(std::move(data[index])); +} + +template<typename T, std::size_t Rows, std::size_t Cols> void mm::basic_matrix<T, Rows, Cols>::swap_rows(std::size_t x, std::size_t y) { if (x == y) return; @@ -124,7 +178,7 @@ mm::basic_matrix<T, N, M> mm::basic_matrix<T, M, N>::transposed() { for (int row = 0; row < M; row++) for (int col = 0; col < N; col++) - result.at(row, col) = at(col, row); + result[row][col] = this[col][row]; return result; } @@ -203,3 +257,38 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols> return os; } + + +/* square matrix specializaiton */ + +template<typename T, std::size_t N> +class mm::square_matrix : public mm::basic_matrix<T, N, N> { +public: + /// in place transpose + void transpose(); + inline void tr() { transpose(); } + + /// in place inverse + void invert(); +}; + + +template<typename T, std::size_t N> +void mm::square_matrix<T, N>::transpose() { + for (int row = 0; row < N; row++) + for (int col = 0; col < row; col++) + std::swap(this->at(row, col), this->at(col, row)); +} + + +/* row vector specialization */ +template<typename T, std::size_t Rows> +class mm::row_vec : public mm::basic_matrix<T, Rows, 1> { +public: +}; + +/* column vector specialization */ +template<typename T, std::size_t Cols> +class mm::col_vec : public mm::basic_matrix<T, 1, Cols> { +public: +}; |