/* mmmatrix.hpp
 * Part of Mathematical library built (ab)using Modern C++ 17 abstractions.
 *
 * This library is not intended to be _performant_, it does not contain
 * hand written SIMD / SSE / AVX optimizations. It is instead an example
 * of highly inefficient (but abstract!) code, where matrices can contain any
 * data type.
 *
 * Naoki Pross
 * 2018 ~ 2019
 */
#pragma once

#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <ostream>
#include <utility>

namespace mm {
    template<typename T, std::size_t Rows, std::size_t Cols>
    class basic_matrix;

    template<typename T, std::size_t Rows, std::size_t Cols>
    class matrix;

    template<typename T, std::size_t N>
    class square_matrix;

    // template<typename T, std::size_t N>
    // class diag_matrix;

    /// 1 x N matrix
    template<typename T, std::size_t N>
    class row_vec;

    /// N x 1 matrix
    template<typename T, std::size_t N>
    class col_vec;
}

/// Fixed size Rows x Cols matrix of T, stored row-major in a plain array.
///
/// Every element starts value-initialized (`T{}`) unless a constructor
/// overwrites it. All element copies are done element-wise (never with
/// memcpy), so T does not have to be trivially copyable.
template<typename T, std::size_t Rows, std::size_t Cols>
class mm::basic_matrix
{
public:
    using type = T;

    static constexpr std::size_t rows = Rows;
    static constexpr std::size_t cols = Cols;

    /// value-initialized (zero for arithmetic T) matrix; the arithmetic
    /// operators below build their result into one of these
    basic_matrix() = default;

    /// from a nested initializer list, one inner list per row:
    /// `basic_matrix<int, 2, 2> m {{1, 2}, {3, 4}};`
    basic_matrix(std::initializer_list<std::initializer_list<T>> l);

    // copyable and movable (element-wise)
    basic_matrix(const basic_matrix& other) = default;
    basic_matrix(basic_matrix&& other) = default;
    basic_matrix& operator=(const basic_matrix& other) = default;
    basic_matrix& operator=(basic_matrix&& other) = default;

    /// copy from a matrix that is at most as tall and as wide as this one;
    /// the elements not covered by `other` stay value-initialized
    template<std::size_t ORows, std::size_t OCols>
    basic_matrix(const basic_matrix<T, ORows, OCols>& other);

    // copy or move from 2D array
    basic_matrix(const T (& values)[Rows][Cols]);
    basic_matrix(T (&& values)[Rows][Cols]);

    // access data (bounds are checked with assert)
    T& at(std::size_t row, std::size_t col);
    const T& at(std::size_t row, std::size_t col) const;

    /// element access by a single index, only for row or column vectors
    auto&& operator[](std::size_t index);

    void swap_rows(std::size_t x, std::size_t y);
    void swap_cols(std::size_t x, std::size_t y);

    // mathematical operations
    basic_matrix<T, Cols, Rows> transposed() const;
    inline basic_matrix<T, Cols, Rows> trd() const { return transposed(); }

    // bool is_invertible() const;
    // basic_matrix inverse() const;

    /// downcast to square matrix
    static constexpr bool is_square() { return (Rows == Cols); }
    inline square_matrix<T, Rows> to_square() const {
        static_assert(Rows == Cols, "only a square matrix can become a square_matrix");
        return square_matrix<T, Rows>(*this);
    }

    /// downcast to row vector (a row vector has exactly one row: 1 x N)
    static constexpr bool is_row_vec() { return (Rows == 1); }
    inline row_vec<T, Cols> to_row_vec() const {
        static_assert(Rows == 1, "only a 1 x N matrix can become a row_vec");
        return row_vec<T, Cols>(*this);
    }

    /// downcast to column vector (a column vector has exactly one column: N x 1)
    static constexpr bool is_col_vec() { return (Cols == 1); }
    inline col_vec<T, Rows> to_col_vec() const {
        static_assert(Cols == 1, "only a N x 1 matrix can become a col_vec");
        return col_vec<T, Rows>(*this);
    }

private:
    T data[Rows * Cols] = {};
};

template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Rows, Cols>::basic_matrix(
    std::initializer_list<std::initializer_list<T>> l
) {
    assert(l.size() == Rows);

    // extra rows / columns are ignored (and missing ones stay T{}) when
    // asserts are disabled
    std::size_t row = 0;
    for (const auto& row_values : l) {
        if (row >= Rows)
            break;

        assert(row_values.size() == Cols);

        std::size_t col = 0;
        for (const T& value : row_values) {
            if (col >= Cols)
                break;

            this->at(row, col) = value;
            ++col;
        }
        ++row;
    }
}

template<typename T, std::size_t Rows, std::size_t Cols>
template<std::size_t ORows, std::size_t OCols>
mm::basic_matrix<T, Rows, Cols>::basic_matrix(
    const mm::basic_matrix<T, ORows, OCols>& other
) {
    static_assert((ORows <= Rows),
        "cannot copy a taller matrix into a smaller one"
    );

    static_assert((OCols <= Cols),
        "cannot copy a larger matrix into a smaller one"
    );

    // copy only the overlapping region: the two matrices have different
    // row strides, so a flat copy of the storage would misplace elements
    // (and read past the end of the smaller one)
    for (std::size_t row = 0; row < ORows; row++)
        for (std::size_t col = 0; col < OCols; col++)
            this->at(row, col) = other.at(row, col);
}

template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Rows, Cols>::basic_matrix(const T (& values)[Rows][Cols])
{
    for (std::size_t row = 0; row < Rows; row++)
        for (std::size_t col = 0; col < Cols; col++)
            this->at(row, col) = values[row][col];
}

template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Rows, Cols>::basic_matrix(T (&& values)[Rows][Cols])
{
    // raw arrays cannot be assigned, move every element individually
    for (std::size_t row = 0; row < Rows; row++)
        for (std::size_t col = 0; col < Cols; col++)
            this->at(row, col) = std::move(values[row][col]);
}

/* member functions */

template<typename T, std::size_t Rows, std::size_t Cols>
T& mm::basic_matrix<T, Rows, Cols>::at(std::size_t row, std::size_t col)
{
    assert(row < Rows); // "out of row bound"
    assert(col < Cols); // "out of column bound"

    return data[row * Cols + col];
}

template<typename T, std::size_t Rows, std::size_t Cols>
const T& mm::basic_matrix<T, Rows, Cols>::at(std::size_t row, std::size_t col) const
{
    assert(row < Rows); // "out of row bound"
    assert(col < Cols); // "out of column bound"

    return data[row * Cols + col];
}

template<typename T, std::size_t Rows, std::size_t Cols>
auto&& mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index)
{
    // TODO: for a general matrix this could return a view of row `index`
    static_assert((Rows == 1) || (Cols == 1),
        "operator[] is only defined for row (1 x N) or column (N x 1) vectors"
    );

    // the index is a runtime value, so it can only be checked at runtime;
    // in a vector the elements are contiguous whatever its orientation
    assert(index < Rows * Cols); // "out of bound"

    return data[index];
}

template<typename T, std::size_t Rows, std::size_t Cols>
void mm::basic_matrix<T, Rows, Cols>::swap_rows(std::size_t x, std::size_t y)
{
    if (x == y)
        return;

    for (std::size_t col = 0; col < Cols; col++)
        std::swap(this->at(x, col), this->at(y, col));
}

template<typename T, std::size_t Rows, std::size_t Cols>
void mm::basic_matrix<T, Rows, Cols>::swap_cols(std::size_t x, std::size_t y)
{
    if (x == y)
        return;

    for (std::size_t row = 0; row < Rows; row++)
        std::swap(this->at(row, x), this->at(row, y));
}

template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Cols, Rows> mm::basic_matrix<T, Rows, Cols>::transposed() const
{
    // a Rows x Cols matrix transposes into a Cols x Rows one
    mm::basic_matrix<T, Cols, Rows> result;

    for (std::size_t row = 0; row < Rows; row++)
        for (std::size_t col = 0; col < Cols; col++)
            result.at(col, row) = this->at(row, col);

    return result;
}

/* operator overloading */

/// element-wise sum
template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Rows, Cols> operator+(
    const mm::basic_matrix<T, Rows, Cols>& a,
    const mm::basic_matrix<T, Rows, Cols>& b
) {
    mm::basic_matrix<T, Rows, Cols> result;

    for (std::size_t row = 0; row < Rows; row++)
        for (std::size_t col = 0; col < Cols; col++)
            result.at(row, col) = a.at(row, col) + b.at(row, col);

    return result;
}

/// scalar product; T is deduced from the matrix only (the scalar is a
/// non-deduced context), so `double_matrix * 2` compiles
template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Rows, Cols> operator*(
    const mm::basic_matrix<T, Rows, Cols>& m,
    const typename mm::basic_matrix<T, Rows, Cols>::type& scalar
) {
    mm::basic_matrix<T, Rows, Cols> result;

    for (std::size_t row = 0; row < Rows; row++)
        for (std::size_t col = 0; col < Cols; col++)
            result.at(row, col) = m.at(row, col) * scalar;

    return result;
}

/// scalar product with the scalar on the left
template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Rows, Cols> operator*(
    const typename mm::basic_matrix<T, Rows, Cols>::type& scalar,
    const mm::basic_matrix<T, Rows, Cols>& m
) {
    return m * scalar;
}

/// matrix product of a (M x P) and b (P x N), giving a M x N matrix
template<typename T, std::size_t M, std::size_t P, std::size_t N>
mm::basic_matrix<T, M, N> operator*(
    const mm::basic_matrix<T, M, P>& a,
    const mm::basic_matrix<T, P, N>& b
) {
    mm::basic_matrix<T, M, N> result;

    // TODO: use a more efficient algorithm
    for (std::size_t row = 0; row < M; row++) {
        for (std::size_t col = 0; col < N; col++) {
            // each entry is the dot product of a row of a and a column of b
            T sum{};
            for (std::size_t k = 0; k < P; k++)
                sum += a.at(row, k) * b.at(k, col);

            result.at(row, col) = sum;
        }
    }

    return result;
}

/// element-wise difference
template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Rows, Cols> operator-(
    const mm::basic_matrix<T, Rows, Cols>& a,
    const mm::basic_matrix<T, Rows, Cols>& b
) {
    return a + static_cast<T>(-1) * b;
}

/// prints one bracketed, comma separated line per row
template<typename T, std::size_t Rows, std::size_t Cols>
std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols>& m)
{
    for (std::size_t row = 0; row < Rows; row++) {
        os << "[ ";
        for (std::size_t col = 0; col < (Cols -1); col++) {
            os << m.at(row, col) << ", ";
        }
        os << m.at(row, (Cols -1)) << " ]\n";
    }

    return os;
}

/* square matrix specialization */

template<typename T, std::size_t N>
class mm::square_matrix : public mm::basic_matrix<T, N, N>
{
public:
    using mm::basic_matrix<T, N, N>::basic_matrix;

    square_matrix() = default;

    /// conversion from the base class, used by basic_matrix::to_square()
    /// (inherited copy-like constructors are not usable for this)
    explicit square_matrix(const mm::basic_matrix<T, N, N>& other)
        : mm::basic_matrix<T, N, N>(other) {}

    /// in place transpose
    void transpose();
    inline void tr() { transpose(); }

    /// in place inverse
    /// NOTE(review): declared but not defined in this header
    void invert();
};

template<typename T, std::size_t N>
void mm::square_matrix<T, N>::transpose()
{
    // swap each element below the diagonal with its mirror above it
    for (std::size_t row = 0; row < N; row++)
        for (std::size_t col = 0; col < row; col++)
            std::swap(this->at(row, col), this->at(col, row));
}

/* row vector specialization (1 x N) */

template<typename T, std::size_t N>
class mm::row_vec : public mm::basic_matrix<T, 1, N>
{
public:
    using mm::basic_matrix<T, 1, N>::basic_matrix;

    row_vec() = default;

    /// conversion from the base class, used by basic_matrix::to_row_vec()
    explicit row_vec(const mm::basic_matrix<T, 1, N>& other)
        : mm::basic_matrix<T, 1, N>(other) {}
};

/* column vector specialization (N x 1) */

template<typename T, std::size_t N>
class mm::col_vec : public mm::basic_matrix<T, N, 1>
{
public:
    using mm::basic_matrix<T, N, 1>::basic_matrix;

    col_vec() = default;

    /// conversion from the base class, used by basic_matrix::to_col_vec()
    explicit col_vec(const mm::basic_matrix<T, N, 1>& other)
        : mm::basic_matrix<T, N, 1>(other) {}
};

template<typename T, std::size_t Rows, std::size_t Cols>
class mm::matrix : public mm::basic_matrix<T, Rows, Cols>
{
public:
    using mm::basic_matrix<T, Rows, Cols>::basic_matrix;
};