summaryrefslogtreecommitdiffstats
path: root/include/mm/mmmatrix.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/mm/mmmatrix.hpp')
-rw-r--r--include/mm/mmmatrix.hpp405
1 files changed, 396 insertions, 9 deletions
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp
index 5285429..c11b626 100644
--- a/include/mm/mmmatrix.hpp
+++ b/include/mm/mmmatrix.hpp
@@ -19,9 +19,14 @@
#include <array>
namespace mm {
+
template<typename T, std::size_t Rows, std::size_t Cols>
class basic_matrix;
+ // TODO, not sure it's a good idea
+ //template<typename T, std::size_t Rows, std::size_t Cols>
+ //class transposed_matrix;
+
/* specialization of basic_matrx for Cols = 1 */
template<typename T, std::size_t Rows>
class row_vec;
@@ -38,10 +43,251 @@ namespace mm {
template<typename T, std::size_t N>
class square_matrix;
- // template<typename T, std::size_t N>
- // class diag_matrix;
+ template<typename T, std::size_t N>
+ class diagonal_matrix;
+
+ /*
+ * Iterators
+ */
+
+ template<typename T, std::size_t Rows, std::size_t Cols>
+ class vector_iterator;
+
+ template<typename T, std::size_t N>
+ class diag_iterator;
+
+ template<typename T, std::size_t Rows, std::size_t Cols>
+ class const_vector_iterator;
+
+ template<typename T, std::size_t N>
+ class const_diag_iterator;
}
+/* Non-const Iterators */
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+class mm::vector_iterator
+{
+ std::size_t index; // variable index
+
+ mm::basic_matrix<T, Rows, Cols>& M;
+
+ const std::size_t position; // fixed index
+ const bool direction; // true = row, false = column
+
+public:
+ template<typename U, std::size_t ORows, std::size_t OCols>
+ friend class vector_iterator;
+
+ vector_iterator(mm::basic_matrix<T, Rows, Cols>& M, std::size_t position, bool direction);
+
+ mm::vector_iterator<T, Rows, Cols> operator++()
+ {
+ vector_iterator<T, Rows, Cols> it = *this;
+ ++index;
+ return it;
+ }
+
+ mm::vector_iterator<T, Rows, Cols> operator--()
+ {
+ vector_iterator<T, Rows, Cols> it = *this;
+ --index;
+ return it;
+ }
+
+ mm::vector_iterator<T, Rows, Cols>& operator++(int)
+ {
+ ++index;
+ return *this;
+ }
+
+ mm::vector_iterator<T, Rows, Cols>& operator--(int)
+ {
+ --index;
+ return *this;
+ }
+
+ bool operator==(const mm::vector_iterator<T, Rows, Cols>& other) const
+ {
+ return index == other.index;
+ }
+
+ bool operator=!(const mm::vector_iterator<T, Rows, Cols>& other) const
+ {
+ return index != other.index;
+ }
+
+ T& operator*() const;
+ T& operator[](std::size_t);
+};
+
+template<typename T, std::size_t N>
+class diag_iterator
+{
+ std::size_t index; // variable index
+
+ mm::square_matrix<T, N>& M;
+
+ const int position; // fixed diagonal index
+
+public:
+ template<typename U, std::size_t ON>
+ friend class diag_iterator;
+
+ diag_iterator(mm::square_matrix<T, N>& M, std::size_t position, bool direction);
+
+ mm::diag_iterator<T, N> operator++()
+ {
+ diag_iterator<T, N> it = *this;
+ ++index;
+ return it;
+ }
+
+ mm::diag_iterator<T, N> operator--()
+ {
+ diag_iterator<T, N> it = *this;
+ --index;
+ return it;
+ }
+
+ mm::diag_iterator<T, N>& operator++(int)
+ {
+ ++index;
+ return *this;
+ }
+
+ mm::diag_iterator<T, N>& operator--(int)
+ {
+ --index;
+ return *this;
+ }
+
+ bool operator==(const mm::diag_iterator<T, N>& other) const
+ {
+ return index == other.index;
+ }
+
+ bool operator=!(const mm::diag_iterator<T, N>& other) const
+ {
+ return index != other.index;
+ }
+
+ T& operator*() const;
+};
+
+/* Const Iterators */
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+class mm::const_vector_iterator
+{
+ std::size_t index; // variable index
+
+ const mm::basic_matrix<T, Rows, Cols>& M;
+
+ const std::size_t position; // fixed index
+ const bool direction; // true = row, false = column
+
+public:
+ const_vector_iterator(mm::basic_matrix<T, Rows, Cols>& M, std::size_t position, bool direction);
+
+ mm::const_vector_iterator<T, Rows, Cols> operator++()
+ {
+ vector_iterator<T, Rows, Cols> it = *this;
+ ++index;
+ return it;
+ }
+
+ mm::const_vector_iterator<T, Rows, Cols> operator--()
+ {
+ vector_iterator<T, Rows, Cols> it = *this;
+ --index;
+ return it;
+ }
+
+ mm::const_vector_iterator<T, Rows, Cols>& operator++(int)
+ {
+ ++index;
+ return *this;
+ }
+
+ mm::const_vector_iterator<T, Rows, Cols>& operator--(int)
+ {
+ --index;
+ return *this;
+ }
+
+ bool operator==(const mm::const_vector_iterator<T, Rows, Cols>& other) const
+ {
+ return index == other;
+ }
+
+ bool operator=!(const mm::const_vector_iterator<T, Rows, Cols>& other) const
+ {
+ return index != other;
+ }
+
+ const T& operator*() const;
+ const T& operator[](std::size_t) const;
+};
+
+template<typename T>
+class const_diag_iterator
+{
+ std::size_t index; // variable index
+
+ const mm::square_matrix<T, N>& M;
+
+ const int position; // fixed diagonal index
+
+public:
+ template<typename U, std::size_t ON>
+ friend class const_diag_iterator;
+
+ const_diag_iterator(const mm::square_matrix<T, N>& M, std::size_t position, bool direction);
+
+ mm::const_diag_iterator<T, N> operator++()
+ {
+ const_diag_iterator<T, N> it = *this;
+ ++index;
+ return it;
+ }
+
+ mm::const_diag_iterator<T, N> operator--()
+ {
+ const_diag_iterator<T, N> it = *this;
+ --index;
+ return it;
+ }
+
+ mm::const_diag_iterator<T, N>& operator++(int)
+ {
+ ++index;
+ return *this;
+ }
+
+ mm::const_diag_iterator<T, N>& operator--(int)
+ {
+ --index;
+ return *this;
+ }
+
+ bool operator==(const mm::const_diag_iterator<T, N>& other) const
+ {
+ return index == other.index;
+ }
+
+ bool operator=!(const mm::const_diag_iterator<T, N>& other) const
+ {
+ return index != other.index;
+ }
+
+ const T& operator*() const;
+};
+
+/*
+ * Matrix class
+ */
+
template<typename T, std::size_t Rows, std::size_t Cols>
class mm::basic_matrix {
public:
@@ -50,6 +296,9 @@ public:
template<typename U, std::size_t ORows, std::size_t OCols>
friend class mm::basic_matrix;
+ template<typename U, std::size_t ORows, std::size_t OCols>
+ friend class mm::vector_iterator;
+
static constexpr std::size_t rows = Rows;
static constexpr std::size_t cols = Cols;
@@ -67,21 +316,20 @@ public:
basic_matrix(const basic_matrix<T, ORows, OCols>& other);
// access data
- T& at(std::size_t row, std::size_t col);
- const T& at(std::size_t row, std::size_t col) const;
+ virtual T& at(std::size_t row, std::size_t col);
+ virtual const T& at(std::size_t row, std::size_t col) const;
+
// allows to access a matrix M at row j col k with M[j][k]
- auto operator[](std::size_t index);
+ virtual auto operator[](std::size_t index);
void swap_rows(std::size_t x, std::size_t y);
void swap_cols(std::size_t x, std::size_t y);
// mathematical operations
+ // TODO, simply switch iteration mode
virtual basic_matrix<T, Cols, Rows> transposed() const;
inline basic_matrix<T, Cols, Rows> td() const { return transposed(); }
- // bool is_invertible() const;
- // basic_matrix<T, Rows, Cols> inverse() const;
-
/// downcast to square matrix
static inline constexpr bool is_square() { return (Rows == Cols); }
@@ -307,6 +555,9 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols>
+/*
+ * derivated classes
+ */
/* row vector specialization */
template<typename T, std::size_t Rows>
@@ -329,7 +580,35 @@ public:
using mm::basic_matrix<T, Rows, Cols>::basic_matrix;
};
-/* square matrix specializaiton */
+/*
+ * transposed matrix format
+ * TODO: write this class, or put a bool flag into the original one
+ */
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+class mm::transposed_matrix : public mm::basic_matrix<T, Rows, Cols>
+{
+public:
+ using mm::basic_matrix<T, Rows, Cols>::basic_matrix;
+
+ virtual T& at(std::size_t row, std::size_t col) override
+ {
+ return mm::basic_matrix<T, Rows, Cols>::at(col, row);
+ }
+
+ virtual const T& at(std::size_t row, std::size_t col) const override
+ {
+ return mm::basic_matrix<T, Rows, Cols>::at(col, row);
+ }
+
+ // allows to access a matrix M at row j col k with M[j][k]
+ virtual auto operator[](std::size_t index) override
+ {
+ // TODO, return other direction iterator
+ }
+}
+
+/* square matrix specialization */
template<typename T, std::size_t N>
class mm::square_matrix : public mm::basic_matrix<T, N, N> {
public:
@@ -343,8 +622,20 @@ public:
inline T tr() { return trace(); }
/// in place inverse
+ // TODO, det != 0
+ // TODO, use gauss jordan for invertible ones
void invert();
+
+ // TODO, downcast to K-diagonal, user defined cast
+ template<int K>
+ operator mm::diagonal_matrix<T, N, K>() const
+ {
+ // it's always possible to do it bidirectionally,
+ // without loosing information
+ return dynamic_cast<mm::diagonal_matrix<T, N, K>>(*this);
+ }
+
// get the identity of size N
static inline constexpr square_matrix<T, N> identity() {
square_matrix<T, N> i;
@@ -356,6 +647,20 @@ public:
}
};
+/*
+ * K-diagonal square matrix format
+ * K is bounded between ]-N, N[
+ */
+
+template<typename T, std::size_t N, int K>
+class mm::diagonal_matrix : public mm::square_matrix
+{
+public:
+ using mm::square_matrix<T, N>::square_matrix;
+
+ // TODO, redefine at, operator[]
+ // TODO, matrix multiplication
+};
template<typename T, std::size_t N>
void mm::square_matrix<T, N>::transpose() {
@@ -372,3 +677,85 @@ T mm::square_matrix<T, N>::trace() {
return sum;
}
+
+/* Iterators implementations */
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::vector_iterator<T, Rows, Cols>::vector_iterator(mm::basic_matrix<T, Rows, Cols>& _M, std::size_t pos, bool dir)
+ index(0), M(_M), position(pos), direction(dir)
+{
+ assert((dir && pos < Cols) || (!dir && pos < Rows))
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+T& mm::vector_iterator<T, Rows, Cols>::operator*() const
+{
+ return (direction) ?
+ M.data[position * Cols + index] :
+ M.data[index * Cols + position];
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+T& mm::vector_iterator<T, Rows, Cols>::operator[](std::size_t i)
+{
+ return (direction) ?
+ M.data[position * Cols + i] :
+ M.data[i * Cols + position];
+}
+
+template<typename T, std::size_t N>
+mm::diag_iterator<T, N>::diag_iterator(mm::square_matrix<T, N>& _M, int pos)
+ index(0), M(_M), position(pos)
+{
+ assert(abs(pos) < N) // pos bounded between ]-N, N[
+}
+
+template<typename T, std::size_t N>
+T& mm::diag_iterator<T, N>::operator*() const
+{
+ return (k > 0) ?
+ M.data[(index + position) * Cols + index] :
+ M.data[index * Cols + (index - position)];
+}
+
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::const_vector_iterator<T, Rows, Cols>::const_vector_iterator(const mm::basic_matrix<T, Rows, Cols>& _M, std::size_t pos, bool dir)
+ index(0), M(_M), position(pos), direction(dir)
+{
+ assert((dir && pos < Cols) || (!dir && pos < Rows))
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+const T& mm::const_vector_iterator<T, Rows, Cols>::operator*() const
+{
+ return (direction) ?
+ M.data[position * Cols + index] :
+ M.data[index * Cols + position];
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+const T& mm::const_vector_iterator<T, Rows, Cols>::operator[](std::size_t i) const
+{
+ return (direction) ?
+ M.data[position * Cols + i] :
+ M.data[i * Cols + position];
+}
+
+template<typename T, std::size_t N>
+mm::const_diag_iterator<T, N>::const_diag_iterator(const mm::square_matrix<T, N>& _M, int pos)
+ index(0), M(_M), position(pos)
+{
+ assert(abs(pos) < N) // pos bounded between ]-N, N[
+}
+
+template<typename T, std::size_t N>
+T& mm::const_diag_iterator<T, N>::operator*() const
+{
+ return (k > 0) ?
+ M.data[(index + position) * Cols + index] :
+ M.data[index * Cols + (index - position)];
+}
+
+
+