diff options
Diffstat (limited to 'include/mm/mmmatrix.hpp')
-rw-r--r-- | include/mm/mmmatrix.hpp | 405 |
1 files changed, 396 insertions, 9 deletions
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index 5285429..c11b626 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -19,9 +19,14 @@ #include <array> namespace mm { + template<typename T, std::size_t Rows, std::size_t Cols> class basic_matrix; + // TODO, not sure it's a good idea + //template<typename T, std::size_t Rows, std::size_t Cols> + //class transposed_matrix; + /* specialization of basic_matrx for Cols = 1 */ template<typename T, std::size_t Rows> class row_vec; @@ -38,10 +43,251 @@ namespace mm { template<typename T, std::size_t N> class square_matrix; - // template<typename T, std::size_t N> - // class diag_matrix; + template<typename T, std::size_t N> + class diagonal_matrix; + + /* + * Iterators + */ + + template<typename T, std::size_t Rows, std::size_t Cols> + class vector_iterator; + + template<typename T, std::size_t N> + class diag_iterator; + + template<typename T, std::size_t Rows, std::size_t Cols> + class const_vector_iterator; + + template<typename T, std::size_t N> + class const_diag_iterator; } +/* Non-const Iterators */ + +template<typename T, std::size_t Rows, std::size_t Cols> +class mm::vector_iterator +{ + std::size_t index; // variable index + + mm::basic_matrix<T, Rows, Cols>& M; + + const std::size_t position; // fixed index + const bool direction; // true = row, false = column + +public: + template<typename U, std::size_t ORows, std::size_t OCols> + friend class vector_iterator; + + vector_iterator(mm::basic_matrix<T, Rows, Cols>& M, std::size_t position, bool direction); + + mm::vector_iterator<T, Rows, Cols> operator++() + { + vector_iterator<T, Rows, Cols> it = *this; + ++index; + return it; + } + + mm::vector_iterator<T, Rows, Cols> operator--() + { + vector_iterator<T, Rows, Cols> it = *this; + --index; + return it; + } + + mm::vector_iterator<T, Rows, Cols>& operator++(int) + { + ++index; + return *this; + } + + mm::vector_iterator<T, Rows, Cols>& operator--(int) + { + --index; + return *this; + } + + bool operator==(const mm::vector_iterator<T, Rows, Cols>& other) const + { + return index == other.index; + } + + bool operator=!(const mm::vector_iterator<T, Rows, Cols>& other) const + { + return index != other.index; + } + + T& operator*() const; + T& operator[](std::size_t); +}; + +template<typename T, std::size_t N> +class diag_iterator +{ + std::size_t index; // variable index + + mm::square_matrix<T, N>& M; + + const int position; // fixed diagonal index + +public: + template<typename U, std::size_t ON> + friend class diag_iterator; + + diag_iterator(mm::square_matrix<T, N>& M, std::size_t position, bool direction); + + mm::diag_iterator<T, N> operator++() + { + diag_iterator<T, N> it = *this; + ++index; + return it; + } + + mm::diag_iterator<T, N> operator--() + { + diag_iterator<T, N> it = *this; + --index; + return it; + } + + mm::diag_iterator<T, N>& operator++(int) + { + ++index; + return *this; + } + + mm::diag_iterator<T, N>& operator--(int) + { + --index; + return *this; + } + + bool operator==(const mm::diag_iterator<T, N>& other) const + { + return index == other.index; + } + + bool operator=!(const mm::diag_iterator<T, N>& other) const + { + return index != other.index; + } + + T& operator*() const; +}; + +/* Const Iterators */ + +template<typename T, std::size_t Rows, std::size_t Cols> +class mm::const_vector_iterator +{ + std::size_t index; // variable index + + const mm::basic_matrix<T, Rows, Cols>& M; + + const std::size_t position; // fixed index + const bool direction; // true = row, false = column + +public: + const_vector_iterator(mm::basic_matrix<T, Rows, Cols>& M, std::size_t position, bool direction); + + mm::const_vector_iterator<T, Rows, Cols> operator++() + { + vector_iterator<T, Rows, Cols> it = *this; + ++index; + return it; + } + + mm::const_vector_iterator<T, Rows, Cols> operator--() + { + vector_iterator<T, Rows, Cols> it = *this; + --index; + return it; + } + + mm::const_vector_iterator<T, Rows, Cols>& operator++(int) + { + ++index; + return *this; + } + + mm::const_vector_iterator<T, Rows, Cols>& operator--(int) + { + --index; + return *this; + } + + bool operator==(const mm::const_vector_iterator<T, Rows, Cols>& other) const + { + return index == other; + } + + bool operator=!(const mm::const_vector_iterator<T, Rows, Cols>& other) const + { + return index != other; + } + + const T& operator*() const; + const T& operator[](std::size_t) const; +}; + +template<typename T> +class const_diag_iterator +{ + std::size_t index; // variable index + + const mm::square_matrix<T, N>& M; + + const int position; // fixed diagonal index + +public: + template<typename U, std::size_t ON> + friend class const_diag_iterator; + + const_diag_iterator(const mm::square_matrix<T, N>& M, std::size_t position, bool direction); + + mm::const_diag_iterator<T, N> operator++() + { + const_diag_iterator<T, N> it = *this; + ++index; + return it; + } + + mm::const_diag_iterator<T, N> operator--() + { + const_diag_iterator<T, N> it = *this; + --index; + return it; + } + + mm::const_diag_iterator<T, N>& operator++(int) + { + ++index; + return *this; + } + + mm::const_diag_iterator<T, N>& operator--(int) + { + --index; + return *this; + } + + bool operator==(const mm::const_diag_iterator<T, N>& other) const + { + return index == other.index; + } + + bool operator=!(const mm::const_diag_iterator<T, N>& other) const + { + return index != other.index; + } + + const T& operator*() const; +}; + +/* + * Matrix class + */ + template<typename T, std::size_t Rows, std::size_t Cols> class mm::basic_matrix { public: @@ -50,6 +296,9 @@ public: template<typename U, std::size_t ORows, std::size_t OCols> friend class mm::basic_matrix; + template<typename U, std::size_t ORows, std::size_t OCols> + friend class mm::vector_iterator; + static constexpr std::size_t rows = Rows; static constexpr std::size_t cols = Cols; @@ -67,21 +316,20 @@ public: basic_matrix(const basic_matrix<T, ORows, OCols>& other); // access data - T& at(std::size_t row, std::size_t col); - const T& at(std::size_t row, std::size_t col) const; + virtual T& at(std::size_t row, std::size_t col); + virtual const T& at(std::size_t row, std::size_t col) const; + // allows to access a matrix M at row j col k with M[j][k] - auto operator[](std::size_t index); + virtual auto operator[](std::size_t index); void swap_rows(std::size_t x, std::size_t y); void swap_cols(std::size_t x, std::size_t y); // mathematical operations + // TODO, simply switch iteration mode virtual basic_matrix<T, Cols, Rows> transposed() const; inline basic_matrix<T, Cols, Rows> td() const { return transposed(); } - // bool is_invertible() const; - // basic_matrix<T, Rows, Cols> inverse() const; - /// downcast to square matrix static inline constexpr bool is_square() { return (Rows == Cols); } @@ -307,6 +555,9 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols> +/* + * derivated classes + */ /* row vector specialization */ template<typename T, std::size_t Rows> @@ -329,7 +580,35 @@ public: using mm::basic_matrix<T, Rows, Cols>::basic_matrix; }; -/* square matrix specializaiton */ +/* + * transposed matrix format + * TODO: write this class, or put a bool flag into the original one + */ + +template<typename T, std::size_t Rows, std::size_t Cols> +class mm::transposed_matrix : public mm::basic_matrix<T, Rows, Cols> +{ +public: + using mm::basic_matrix<T, Rows, Cols>::basic_matrix; + + virtual T& at(std::size_t row, std::size_t col) override + { + return mm::basic_matrix<T, Rows, Cols>::at(col, row); + } + + virtual const T& at(std::size_t row, std::size_t col) const override + { + return mm::basic_matrix<T, Rows, Cols>::at(col, row); + } + + // allows to access a matrix M at row j col k with M[j][k] + virtual auto operator[](std::size_t index) override + { + // TODO, return other direction iterator + } +} + +/* square matrix specialization */ template<typename T, std::size_t N> class mm::square_matrix : public mm::basic_matrix<T, N, N> { public: @@ -343,8 +622,20 @@ public: inline T tr() { return trace(); } /// in place inverse + // TODO, det != 0 + // TODO, use gauss jordan for invertible ones void invert(); + + // TODO, downcast to K-diagonal, user defined cast + template<int K> + operator mm::diagonal_matrix<T, N, K>() const + { + // it's always possible to do it bidirectionally, + // without loosing information + return dynamic_cast<mm::diagonal_matrix<T, N, K>>(*this); + } + // get the identity of size N static inline constexpr square_matrix<T, N> identity() { square_matrix<T, N> i; @@ -356,6 +647,20 @@ public: } }; +/* + * K-diagonal square matrix format + * K is bounded between ]-N, N[ + */ + +template<typename T, std::size_t N, int K> +class mm::diagonal_matrix : public mm::square_matrix +{ +public: + using mm::square_matrix<T, N>::square_matrix; + + // TODO, redefine at, operator[] + // TODO, matrix multiplication +}; template<typename T, std::size_t N> void mm::square_matrix<T, N>::transpose() { @@ -372,3 +677,85 @@ T mm::square_matrix<T, N>::trace() { return sum; } + +/* Iterators implementations */ + +template<typename T, std::size_t Rows, std::size_t Cols> +mm::vector_iterator<T, Rows, Cols>::vector_iterator(mm::basic_matrix<T, Rows, Cols>& _M, std::size_t pos, bool dir) + index(0), M(_M), position(pos), direction(dir) +{ + assert((dir && pos < Cols) || (!dir && pos < Rows)) +} + +template<typename T, std::size_t Rows, std::size_t Cols> +T& mm::vector_iterator<T, Rows, Cols>::operator*() const +{ + return (direction) ? + M.data[position * Cols + index] : + M.data[index * Cols + position]; +} + +template<typename T, std::size_t Rows, std::size_t Cols> +T& mm::vector_iterator<T, Rows, Cols>::operator[](std::size_t i) +{ + return (direction) ? + M.data[position * Cols + i] : + M.data[i * Cols + position]; +} + +template<typename T, std::size_t N> +mm::diag_iterator<T, N>::diag_iterator(mm::square_matrix<T, N>& _M, int pos) + index(0), M(_M), position(pos) +{ + assert(abs(pos) < N) // pos bounded between ]-N, N[ +} + +template<typename T, std::size_t N> +T& mm::diag_iterator<T, N>::operator*() const +{ + return (k > 0) ? + M.data[(index + position) * Cols + index] : + M.data[index * Cols + (index - position)]; +} + + +template<typename T, std::size_t Rows, std::size_t Cols> +mm::const_vector_iterator<T, Rows, Cols>::const_vector_iterator(const mm::basic_matrix<T, Rows, Cols>& _M, std::size_t pos, bool dir) + index(0), M(_M), position(pos), direction(dir) +{ + assert((dir && pos < Cols) || (!dir && pos < Rows)) +} + +template<typename T, std::size_t Rows, std::size_t Cols> +const T& mm::const_vector_iterator<T, Rows, Cols>::operator*() const +{ + return (direction) ? + M.data[position * Cols + index] : + M.data[index * Cols + position]; +} + +template<typename T, std::size_t Rows, std::size_t Cols> +const T& mm::const_vector_iterator<T, Rows, Cols>::operator[](std::size_t i) const +{ + return (direction) ? + M.data[position * Cols + i] : + M.data[i * Cols + position]; +} + +template<typename T, std::size_t N> +mm::const_diag_iterator<T, N>::const_diag_iterator(const mm::square_matrix<T, N>& _M, int pos) + index(0), M(_M), position(pos) +{ + assert(abs(pos) < N) // pos bounded between ]-N, N[ +} + +template<typename T, std::size_t N> +T& mm::const_diag_iterator<T, N>::operator*() const +{ + return (k > 0) ? + M.data[(index + position) * Cols + index] : + M.data[index * Cols + (index - position)]; +} + + + |