/* mmmatrix.hpp
 * Part of Mathematical library built (ab)using Modern C++ 17 abstractions.
 *
 * This library is not intended to be _performant_, it does not contain
 * hand written SIMD / SSE / AVX optimizations. It is instead an example
 * of highly inefficient (but abstract!) code, where matrices can contain any
 * data type.
 *
 * Naoki Pross
 * 2018 ~ 2019
 */
#pragma once

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <initializer_list>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <type_traits>
#include <utility>

namespace mm {

    /* storage owner: Rows x Cols elements of type T, kept row-major */
    template<typename T, std::size_t Rows, std::size_t Cols>
    class basic_matrix;

    /* operations that only make sense when Rows == Cols */
    template<typename T, std::size_t N>
    class square_interface;

    /*
     * Most general type of matrix iterator
     * IterType = Row, Column, Diagonal (MM_ROW_ITER, MM_COL_ITER, MM_DIAG_ITER)
     * Grid = constness of mm::basic_matrix
     */
    template<typename T, std::size_t Rows, std::size_t Cols, int IterType, class Grid>
    class vector_iterator;

    /*
     * Access methods interface
     * Regular = true  : indices address the storage directly
     * Regular = false : indices are swapped (transposed view)
     */
    template<typename T, std::size_t Rows, std::size_t Cols, bool Regular>
    class access;

    /* specialisations */

    template<typename T, std::size_t Rows, std::size_t Cols>
    class matrix;   // simple matrix format

    template<typename T, std::size_t Rows, std::size_t Cols>
    class t_matrix; // transposed matrix format

    /* specialization of basic_matrix for Cols = 1 */
    template<typename T, std::size_t Rows>
    class row_vec;

    /* specialization of basic_matrix for Rows = 1 */
    template<typename T, std::size_t Cols>
    class col_vec;  // transposed version of row_vec

    template<typename T, std::size_t N>
    class square_matrix;

    template<typename T, std::size_t N>
    class t_square_matrix;

    /* specialisation of a square_matrix for a sub-diagonal composed matrix */
    // NOTE(review): these two are only declared, never defined in this header;
    // the kind of the third parameter (signed K) is an assumption, confirm.
    template<typename T, std::size_t N, int K>
    class diag_matrix;

    template<typename T, std::size_t N, int K>
    class t_diag_matrix;
}

// TODO, short term solution
#define MM_ROW_ITER 0
#define MM_COL_ITER 1
#define MM_DIAG_ITER 2

namespace mm {

/*
 * Walks along one row, one column or one diagonal of a Grid.
 *
 *  - `position` is fixed for the lifetime of the iterator: the row number
 *    (row iterator), the column number (column iterator) or the diagonal
 *    offset (diagonal iterator; > 0 selects a diagonal below the main one,
 *    <= 0 the main diagonal or one above it).
 *  - `index` is the running coordinate moved by ++ / --.
 *
 * All coordinates refer to the row-major storage of the grid.
 */
template<typename T, std::size_t Rows, std::size_t Cols, int IterType, class Grid>
class vector_iterator
{
    static_assert(IterType == MM_ROW_ITER
               || IterType == MM_COL_ITER
               || IterType == MM_DIAG_ITER,
        "unknown iterator type");

    std::size_t index;  // variable index

    Grid& M;
    const int position; // fixed index, negative too for diagonal iterator

    /// absolute value of the diagonal offset
    std::size_t shift() const
    {
        return static_cast<std::size_t>((position < 0) ? -position : position);
    }

    /// storage row addressed when the running index is `i`
    std::size_t row_of(std::size_t i) const
    {
        if constexpr (IterType == MM_ROW_ITER)
            return static_cast<std::size_t>(position);
        else if constexpr (IterType == MM_COL_ITER)
            return i;
        else
            return (position > 0) ? i + shift() : i;
    }

    /// storage column addressed when the running index is `i`
    std::size_t col_of(std::size_t i) const
    {
        if constexpr (IterType == MM_ROW_ITER)
            return i;
        else if constexpr (IterType == MM_COL_ITER)
            return static_cast<std::size_t>(position);
        else
            return (position > 0) ? i : i + shift();
    }

public:
    template<typename U, std::size_t ORows, std::size_t OCols, int OIterType, class OGrid>
    friend class vector_iterator;

    /// Bind to `grid` at the fixed `pos`, starting at running index `i`.
    /// A row position must be a valid row, a column position a valid column.
    vector_iterator(Grid& grid, int pos, std::size_t i = 0)
        : index(i), M(grid), position(pos)
    {
        if constexpr (IterType == MM_ROW_ITER) {
            assert(pos >= 0 && static_cast<std::size_t>(pos) < Rows);
        } else if constexpr (IterType == MM_COL_ITER) {
            assert(pos >= 0 && static_cast<std::size_t>(pos) < Cols);
        } else {
            assert((pos > 0) ? (static_cast<std::size_t>(pos) < Rows)
                             : (static_cast<std::size_t>(-pos) < Cols));
        }
    }

    /// pre-increment / pre-decrement: move, then yield this iterator
    vector_iterator& operator++() { ++index; return *this; }
    vector_iterator& operator--() { --index; return *this; }

    /// post-increment / post-decrement: move, yield the previous state
    vector_iterator operator++(int)
    {
        vector_iterator previous = *this;
        ++index;
        return previous;
    }

    vector_iterator operator--(int)
    {
        vector_iterator previous = *this;
        --index;
        return previous;
    }

    /// equal when both iterators designate the same element of the same grid
    bool operator==(const vector_iterator& other) const
    {
        return (&M == &other.M)
            && (position == other.position)
            && (index == other.index);
    }

    bool operator!=(const vector_iterator& other) const
    {
        return !(*this == other);
    }

    /// true while the iterator designates an element inside the grid
    bool ok() const
    {
        return (row_of(index) < Rows) && (col_of(index) < Cols);
    }

    /// element at the current running index (no bounds check)
    T& operator*() const
    {
        return M.data[row_of(index) * Cols + col_of(index)];
    }

    /// element at running index `i`, the iterator itself is not moved
    T& operator[](std::size_t i) const
    {
        return M.data[row_of(i) * Cols + col_of(i)];
    }
};

/* Row Iterators */

template<typename T, std::size_t Rows, std::size_t Cols>
using row_iterator = vector_iterator<T, Rows, Cols, MM_ROW_ITER,
    basic_matrix<T, Rows, Cols>>;

template<typename T, std::size_t Rows, std::size_t Cols>
using col_iterator = vector_iterator<T, Rows, Cols, MM_COL_ITER,
    basic_matrix<T, Rows, Cols>>;

template<typename T, std::size_t Rows, std::size_t Cols>
using const_row_iterator = vector_iterator<std::add_const_t<T>, Rows, Cols,
    MM_ROW_ITER, std::add_const_t<basic_matrix<T, Rows, Cols>>>;

template<typename T, std::size_t Rows, std::size_t Cols>
using const_col_iterator = vector_iterator<std::add_const_t<T>, Rows, Cols,
    MM_COL_ITER, std::add_const_t<basic_matrix<T, Rows, Cols>>>;

template<typename T, std::size_t N>
using diag_iterator = vector_iterator<T, N, N, MM_DIAG_ITER,
    basic_matrix<T, N, N>>;

template<typename T, std::size_t N>
using const_diag_iterator = vector_iterator<std::add_const_t<T>, N, N,
    MM_DIAG_ITER, std::add_const_t<basic_matrix<T, N, N>>>;

/*
 * Matrix class, no access methods
 *
 * Owns the Rows * Cols elements in a row-major std::array. Element access is
 * provided by mm::access; every interface class derives *virtually* from
 * basic_matrix so that a concrete matrix holds exactly one copy of the data.
 */
template<typename T, std::size_t Rows, std::size_t Cols>
class basic_matrix
{
public:
    using type = T;

    template<typename U, std::size_t ORows, std::size_t OCols>
    friend class basic_matrix;

    // iterators read and write the storage directly
    template<typename U, std::size_t ORows, std::size_t OCols, int IterType, class Grid>
    friend class vector_iterator;

    static constexpr std::size_t rows = Rows;
    static constexpr std::size_t cols = Cols;

    /// every element is set to 0
    basic_matrix()
    {
        std::fill(data.begin(), data.end(), 0);
    }

    /// from initializer_list, one inner list per row.
    /// Expects exactly Rows lists of Cols values (asserted). Missing values
    /// stay 0 and surplus values are ignored, so the storage is never overrun.
    basic_matrix(std::initializer_list<std::initializer_list<T>> l)
    {
        assert(l.size() == Rows);
        std::fill(data.begin(), data.end(), 0);

        std::size_t row = 0;
        for (const auto& values : l) {
            if (row == Rows)
                break;

            assert(values.size() == Cols);
            std::copy_n(values.begin(),
                std::min<std::size_t>(values.size(), Cols),
                data.data() + row * Cols);
            ++row;
        }
    }

    // copyable and movable
    basic_matrix(const basic_matrix& other) = default;
    basic_matrix(basic_matrix&& other) = default;
    basic_matrix& operator=(const basic_matrix& other) = default;
    basic_matrix& operator=(basic_matrix&& other) = default;

    /// copy from another (smaller or equal) matrix into the top left corner,
    /// the remaining elements are 0
    template<std::size_t ORows, std::size_t OCols>
    basic_matrix(const basic_matrix<T, ORows, OCols>& other)
    {
        static_assert((ORows <= Rows),
            "cannot copy a taller matrix into a smaller one"
        );

        static_assert((OCols <= Cols),
            "cannot copy a larger matrix into a smaller one"
        );

        std::fill(data.begin(), data.end(), 0);

        // iterate over the *source* extent, it is the smaller one
        for (std::size_t row = 0; row < ORows; row++)
            for (std::size_t col = 0; col < OCols; col++)
                data[row * Cols + col] = other.data[row * OCols + col];
    }

    /// exchange two storage rows.
    /// Note: for transposed formats a storage row is a column of the view.
    void swap_rows(std::size_t x, std::size_t y)
    {
        assert(x < Rows && y < Rows);

        if (x == y)
            return;

        using std::swap;
        for (std::size_t col = 0; col < Cols; col++)
            swap(data[x * Cols + col], data[y * Cols + col]);
    }

    /// exchange two storage columns.
    /// Note: for transposed formats a storage column is a row of the view.
    void swap_cols(std::size_t x, std::size_t y)
    {
        assert(x < Cols && y < Cols);

        if (x == y)
            return;

        using std::swap;
        for (std::size_t row = 0; row < Rows; row++)
            swap(data[row * Cols + x], data[row * Cols + y]);
    }

    // TODO, mathematical operations: transposed() / td()

    /// downcast to square matrix
    static inline constexpr bool is_square() { return (Rows == Cols); }

    inline square_matrix<T, Rows> to_square() const
    {
        static_assert(is_square(), "matrix is not square");

        square_matrix<T, Rows> result;
        basic_matrix<T, Rows, Rows>& storage = result;
        storage.data = data;
        return result;
    }

    /// downcast to row_vector
    static inline constexpr bool is_row_vec() { return (Cols == 1); }

    inline row_vec<T, Rows> to_row_vec() const
    {
        static_assert(is_row_vec(), "matrix is not a row_vec");

        row_vec<T, Rows> result;
        basic_matrix<T, Rows, 1>& storage = result;
        storage.data = data;
        return result;
    }

    /// downcast to col_vector
    static inline constexpr bool is_col_vec() { return (Rows == 1); }

    inline col_vec<T, Cols> to_col_vec() const
    {
        static_assert(is_col_vec(), "matrix is not a col_vec");

        // col_vec<T, Cols> stores Cols x 1 elements and is read transposed,
        // so the flat element order is the same as in this 1 x Cols matrix
        col_vec<T, Cols> result;
        basic_matrix<T, Cols, 1>& storage = result;
        storage.data = data;
        return result;
    }

protected:
    /// fill the storage, in row-major order, from the range [begin, end).
    /// The range is expected to hold at least Rows * Cols values (asserted);
    /// never more than Rows * Cols values are read, missing ones stay 0.
    template<typename ConstIterator>
    basic_matrix(ConstIterator begin, ConstIterator end)
    {
        const auto available = std::distance(begin, end);
        assert(available >= 0
            && static_cast<std::size_t>(available) >= (Rows * Cols));

        std::fill(data.begin(), data.end(), 0);

        const std::size_t count = (available > 0)
            ? std::min<std::size_t>(static_cast<std::size_t>(available), Rows * Cols)
            : 0;

        std::copy_n(begin, count, data.begin());
    }

//private:
    std::array<T, Rows * Cols> data;
};

/* TODO, operator overloading:
 *   matrix + matrix, matrix - matrix, matrix * scalar, scalar * matrix,
 *   matrix * matrix and stream output (operator<<)
 */

/*
 * Accessors
 *
 * Element access on top of the storage of basic_matrix. With Regular = false
 * the two indices are swapped, i.e. the matrix is read as its transpose:
 * at(row, col) then expects row < Cols and col < Rows.
 */
template<typename T, std::size_t Rows, std::size_t Cols, bool Regular>
class access : virtual public basic_matrix<T, Rows, Cols>
{
public:
    T& at(std::size_t row, std::size_t col)
    {
        if constexpr (Regular) {
            assert(row < Rows && col < Cols);
            return this->data[row * Cols + col];
        } else {
            assert(row < Cols && col < Rows);
            return this->data[col * Cols + row]; // transpose
        }
    }

    const T& at(std::size_t row, std::size_t col) const
    {
        if constexpr (Regular) {
            assert(row < Rows && col < Cols);
            return this->data[row * Cols + col];
        } else {
            assert(row < Cols && col < Rows);
            return this->data[col * Cols + row]; // transpose
        }
    }

    /// allows to access a matrix M at row j col k with M[j][k].
    /// For vectors (Rows == 1 or Cols == 1) M[j] is a reference to the
    /// element itself, checked by std::array::at.
    decltype(auto) operator[](std::size_t index)
    {
        if constexpr (Rows == 1 || Cols == 1)
            return this->data.at(index);
        else if constexpr (Regular)
            return row_iterator<T, Rows, Cols>(*this, static_cast<int>(index));
        else
            return col_iterator<T, Rows, Cols>(*this, static_cast<int>(index));
    }

    decltype(auto) operator[](std::size_t index) const
    {
        if constexpr (Rows == 1 || Cols == 1)
            return this->data.at(index);
        else if constexpr (Regular)
            return const_row_iterator<T, Rows, Cols>(*this, static_cast<int>(index));
        else
            return const_col_iterator<T, Rows, Cols>(*this, static_cast<int>(index));
    }
};

/*
 * Square interface
 */
template<typename T, std::size_t N>
class square_interface : virtual public basic_matrix<T, N, N>
{
public:
    /// sum of the elements on the main diagonal
    T trace() const
    {
        T sum = 0;
        for (const_diag_iterator<T, N> it(*this, 0); it.ok(); ++it)
            sum += *it;

        return sum;
    }

    inline T tr() const { return trace(); }

    // TODO, determinant

    /// in place inverse
    // TODO, det != 0
    // TODO, use gauss jordan for invertible ones
    //void invert();, TODO, section algorithm

    // TODO, in place transpose
};

/*
 * derived classes
 */

// simple format matrix
template<typename T, std::size_t Rows, std::size_t Cols>
class matrix
    : virtual public basic_matrix<T, Rows, Cols>
    , virtual public access<T, Rows, Cols, true>
{
public:
    using basic_matrix<T, Rows, Cols>::basic_matrix;
};

// transposed matrix: stores Rows x Cols elements, is read as Cols x Rows
template<typename T, std::size_t Rows, std::size_t Cols>
class t_matrix
    : virtual public basic_matrix<T, Rows, Cols>
    , virtual public access<T, Rows, Cols, false>
{
public:
    using basic_matrix<T, Rows, Cols>::basic_matrix;
};

/* row vector specialization */
template<typename T, std::size_t Rows>
class row_vec : public matrix<T, Rows, 1>
{
public:
    using matrix<T, Rows, 1>::matrix;

    // TODO, begin, end
};

/* column vector specialization */
template<typename T, std::size_t Cols>
class col_vec : public t_matrix<T, Cols, 1>
{
public:
    using t_matrix<T, Cols, 1>::t_matrix;

    // TODO, begin, end
};

/* square matrix specialization */
template<typename T, std::size_t N>
class square_matrix
    : public matrix<T, N, N>
    , virtual public square_interface<T, N>
{
public:
    using matrix<T, N, N>::matrix;

    // get the identity of size N
    static inline square_matrix<T, N> identity()
    {
        square_matrix<T, N> i;
        for (std::size_t row = 0; row < N; row++)
            for (std::size_t col = 0; col < N; col++)
                i.at(row, col) = (row == col) ? 1 : 0;

        return i;
    }
};

template<typename T, std::size_t N>
class t_square_matrix
    : virtual public t_matrix<T, N, N>
    , virtual public square_interface<T, N>
{
public:
    using t_matrix<T, N, N>::t_matrix;
};

/*
 * TODO, K-diagonal square matrix format
 * K is bounded between ]-N, N[
 * (redefine at, operator[] and the matrix multiplication)
 */

} // namespace mm