From cbcb6aeb1510ee1bd2d67d9d0b285df67f7088b5 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 5 Oct 2019 17:59:32 +0200 Subject: New matrix data model (breaks everything) --- include/mm/mmiterator.hpp | 248 ------------------- include/mm/mmmatrix.hpp | 595 +++++----------------------------------------- include/mm/view.hpp | 66 +++++ test/matrix_example.cpp | 22 -- 4 files changed, 132 insertions(+), 799 deletions(-) delete mode 100644 include/mm/mmiterator.hpp create mode 100644 include/mm/view.hpp diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp deleted file mode 100644 index 4efe474..0000000 --- a/include/mm/mmiterator.hpp +++ /dev/null @@ -1,248 +0,0 @@ -#pragma once - -#include "debug.hpp" - -namespace mm::iter { - - template - class vector_iterator; - - template - class basic_iterator; - - template - class diag_iterator; -} - -template -class mm::iter::vector_iterator -{ -public: - - template - friend class mm::iter::basic_iterator; - - template - friend class mm::iter::diag_iterator; - - vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0) - : M(_M), position(pos), index(i) {} - -//#ifdef MM_IMPLICIT_CONVERSION_ITERATOR - operator T&() - { - //npdebug("Calling +") - return *(*this); - } -//#endif - - IterType operator++() - { - IterType it = cpy(); - ++index; - return it; - } - - IterType operator--() - { - IterType it = cpy(); - --index; - return it; - } - - IterType& operator++(int) - { - ++index; - return ref(); - } - - IterType& operator--(int) - { - --index; - return ref(); - } - - bool operator==(const IterType& other) const - { - return index == other.index; - } - - bool operator!=(const IterType& other) const - { - return index != other.index; - } - - bool ok() const - { - return index < size(); - } - - virtual std::size_t size() const = 0; - - virtual T& operator*() = 0; - virtual T& operator[](std::size_t) = 0; - - virtual T& operator[](std::size_t) const = 0; - - IterType begin() - { - return IterType(M, position, 0); - } - - virtual IterType end() = 0; - -protected: - - Grid& M; // grid mapping - - const std::size_t position; // fixed index, negative too for diagonal iterator - std::size_t index; // variable index - - virtual IterType& ref() = 0; - virtual IterType cpy() = 0; -}; - - -/* - * Scalar product - */ - -template -typename std::remove_const::type operator*(const mm::iter::vector_iterator& v, - const mm::iter::vector_iterator& w) -{ - typename std::remove_const::type out(0); - const std::size_t N = std::min(v.size(), w.size()); - - for(unsigned i = 0; i < N; ++i) - out += v[i] * w[i]; - - return out; -} - -template -class mm::iter::basic_iterator : public mm::iter::vector_iterator, Grid> -{ - bool direction; - - virtual mm::iter::basic_iterator& ref() override - { - return *this; - } - - virtual mm::iter::basic_iterator cpy() override - { - return *this; - } - -public: - - basic_iterator(Grid& A, std::size_t pos, std::size_t _index = 0, bool dir = true) - : mm::iter::vector_iterator, Grid> - (A, pos, _index), direction(dir) - { - //npdebug("Position: ", pos, ", Rows: ", Rows, " Cols: ", Cols, ", Direction: ", dir) - - if (direction) - assert(pos < Rows); - else - assert(pos < Cols); - } - - virtual std::size_t size() const - { - return (direction) ? Cols : Rows; - } - - - virtual T& operator*() override - { - return (direction) ? - this->M.data[this->position * Cols + this->index] : - this->M.data[this->index * Cols + this->position]; - - } - - virtual T& operator[](std::size_t i) override - { - return (direction) ? - this->M.data[this->position * Cols + i] : - this->M.data[i * Cols + this->position]; - } - - virtual T& operator[](std::size_t i) const override - { - return (direction) ? - this->M.data[this->position * Cols + i] : - this->M.data[i * Cols + this->position]; - } - - virtual mm::iter::basic_iterator end() - { - return mm::iter::basic_iterator(this->M, this->position, - (direction) ? Cols : Rows); - } -}; - -template -class mm::iter::diag_iterator : public mm::iter::vector_iterator, Grid> -{ - bool sign; - - virtual mm::iter::diag_iterator& ref() override - { - return *this; - } - - virtual mm::iter::diag_iterator cpy() override - { - return *this; - } - -public: - - diag_iterator(Grid& A, signed long int pos, std::size_t _index = 0) - : mm::iter::vector_iterator, Grid> - (A, static_cast(labs(pos)), _index), sign(pos >= 0) - { - assert(this->position < N); - } - - virtual std::size_t size() const - { - return N - this->position; - } - - virtual T& operator*() override - { - return (sign) ? - this->M.data[(this->index - this->position) * N + this->index] : - this->M.data[this->index * N + (this->index + this->position)]; - } - - virtual T& operator[](std::size_t i) override - { - return (sign) ? - this->M.data[(i - this->position) * N + i] : - this->M.data[i * N + (i + this->position)]; - } - - virtual T& operator[](std::size_t i) const override - { - return (sign) ? - this->M.data[(i - this->position) * N + i] : - this->M.data[i * N + (i + this->position)]; - } - - - virtual mm::iter::diag_iterator end() - { - return mm::iter::diag_iterator(this->M, this->position, N); - } -}; - - diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index facbb48..d653d72 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -12,580 +12,117 @@ #pragma once #include -#include -#include #include #include #include #include -#include "mm/mmiterator.hpp" namespace mm { - - /* basic grid structure */ + using index = std::size_t; template class basic_matrix; - /* basic wrapper */ - - template - class matrix; // simple matrix format - /* specialisations */ - - /* specialization of a matrix */ - template - class vector; // by default, set a column vector + template + class matrix; template class square_matrix; - /* specialisation of a square_matrix for a sub-diagonal composed matrix */ - template - class multi_diag_matrix; + template + class diagonal_matrix; } /* * Matrix class, no access methods */ - -template -class mm::basic_matrix -{ -public: - using type = T; - - template - friend class mm::basic_matrix; - - template - friend class mm::matrix; - - template - friend class mm::iter::basic_iterator; - - template - friend class mm::iter::diag_iterator; - - //template - //friend class mm::iter::basic_iterator>; - - //template - //friend class mm::iter::basic_iterator::type, Rows, Cols, typename std::add_const>::type>; - - basic_matrix(); - - // from initializer_list - basic_matrix(std::initializer_list> l); - - // copyable and movable - basic_matrix(const basic_matrix& other) = default; - basic_matrix(basic_matrix&& other) = default; - - // copy from another matrix - template - basic_matrix(const basic_matrix& other); - - void swap_rows(std::size_t x, std::size_t y); - void swap_cols(std::size_t x, std::size_t y); - - // mathematical operations - //virtual basic_matrix transposed() const; - //inline basic_matrix td() const { return transposed(); } - -protected: - template - basic_matrix(ConstIterator begin, ConstIterator end); - -private: - std::array data; -}; - - -template -mm::basic_matrix::basic_matrix() { - std::fill(data.begin(), data.end(), 0); -} - -template -mm::basic_matrix::basic_matrix( - std::initializer_list> l -) { - assert(l.size() == Rows); - auto data_it = data.begin(); - - for (auto&& row : l) { - data_it = std::copy(row.begin(), row.end(), data_it); - } -} - -template -template -mm::basic_matrix::basic_matrix( - const mm::basic_matrix& other -) { - static_assert((ORows <= Rows), - "cannot copy a taller matrix into a smaller one" - ); - - static_assert((OCols <= Cols), - "cannot copy a larger matrix into a smaller one" - ); - - std::fill(data.begin(), data.end(), 0); - for (unsigned row = 0; row < Rows; row++) - for (unsigned col = 0; col < Cols; col++) - this->at(row, col) = other.at(row, col); -} - -/* protected construtor */ -template -template -mm::basic_matrix::basic_matrix( - ConstIterator begin, ConstIterator end -) { - assert(static_cast(std::distance(begin, end)) >= ((Rows * Cols))); - std::copy(begin, end, data.begin()); -} - -template -void mm::basic_matrix::swap_rows(std::size_t x, std::size_t y) { - if (x == y) - return; - - for (unsigned col = 0; col < Cols; col++) - std::swap(this->at(x, col), this->at(y, col)); -} - -template -void mm::basic_matrix::swap_cols(std::size_t x, std::size_t y) { - if (x == y) - return; - - for (unsigned row = 0; row < Rows; row++) - std::swap(this->at(row, x), this->at(row, y)); -} - -/* - * Matrix object - */ - -template -class mm::matrix -{ -public: - - //template - using vec_iterator = mm::iter::basic_iterator>; - - //template - using const_vec_iterator = mm::iter::basic_iterator::type, Rows, Cols, typename std::add_const>::type>; - - // default zeros constructor - matrix() : M(std::make_shared>()), transposed(false) {} - - // from initializer_list - matrix(std::initializer_list> l) - : M(std::make_shared>(l)), transposed(false) {} - - // copyable and movable - matrix(const matrix& other) // deep copy - : M(std::make_shared>(*other.M)), transposed(other.transposed) {} - - matrix(basic_matrix&& other) // move ptr - : M(other.M), transposed(other.transposed) - { - other.M = nullptr; - } - - matrix operator=(const basic_matrix& other) // deep copy - { - *M = *other.M; - transposed = other.transposed; - } - - /* - * Transposition - */ - - matrix& transpose_d() - { - transposed = !transposed; - return *this; - } - - const matrix transpose() const - { - return matrix(M, !transposed); - } - - inline matrix& td() - { - return transpose(); - } - - inline matrix t() const - { - return transpose(); - } - - // strongly transpose - matrix transpose_cpy() const - { - matrix out(); // copy - // TODO - } - - /* - * Pointer status - */ - - bool expired() const +namespace mm { + template + class basic_matrix { - return M == nullptr; - } - - /* - * Downcasting conditions - */ - - /// downcast to square matrix - static inline constexpr bool is_square() { return (Rows == Cols); } - inline constexpr square_matrix to_square() const { - static_assert(is_square()); - return static_cast>(*this); - } - - /// downcast to col_vector - static inline constexpr bool is_vector() { return (Rows == 1 || Cols == 1); } - inline vector to_vector() const { - if constexpr(Cols == 1) - return static_cast>(*this); - else if (Rows == 1) - return vector(*this); // copy into column vector + public: + using type = T; - } + template + friend class mm::matrix; - /* Accessors */ + // copy from another matrix + template + matrix(const basic_matrix& other); - virtual T& at(std::size_t row, std::size_t col) - { - return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; - } + virtual T& at(index row, index col) = 0; + virtual const T& at(index row, index col) const = 0; + }; - virtual const T& at(std::size_t row, std::size_t col) const - { - return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; - } - std::size_t rows() const { - return (transposed) ? Cols : Rows; - } - - std::size_t cols() const { - return (transposed) ? Rows : Cols; - } - - virtual mm::matrix::vec_iterator operator[](std::size_t index) - { - return mm::matrix::vec_iterator(*M, index, 0, !transposed); - } + /* Specializations */ - virtual mm::matrix::const_vec_iterator operator[](std::size_t index) const + template + struct matrix : public basic_matrix { - return mm::matrix::const_vec_iterator(*M, index, 0, !transposed); - } - - /* - * Basic matematical operations (dimension indipendent) - */ - - mm::matrix& operator+=(const mm::matrix& m) { - - for (unsigned row = 0; row < std::min(rows(), m.rows()); ++row) - for (unsigned col = 0; col < std::min(cols(), m.cols()); ++col) - at(row, col) += m.at(row, col); - - return *this; - } - - mm::matrix& operator-=(const mm::matrix& m) { - - for (unsigned row = 0; row < std::min(rows(), m.rows()); ++row) - for (unsigned col = 0; col < std::min(cols(), m.cols()); ++col) - at(row, col) -= m.at(row, col); - - return *this; - } - - mm::matrix operator*=(const T& k) { - - for (unsigned row = 0; row < rows(); ++row) - for (auto& x : (*this)[row]) - x *= k; - - return *this; - } - -protected: - - std::shared_ptr> M; - - // shallow construction - matrix(std::shared_ptr> grid, bool tr = false) : M(grid), transposed(tr) {} - -private: - - bool transposed; -}; - -/* Basic operator overloading (dimension indipendent) */ - -template -mm::matrix operator+( - mm::matrix a, - const mm::matrix& b -) { - return a += b; -} - -template -mm::matrix operator-( - mm::matrix a, - const mm::matrix& b -) { - return a -= b; -} - -template -mm::matrix operator*( - mm::matrix a, - const T& k -) { - return a *= k; -} - -template -mm::matrix operator*( - const T& k, - mm::matrix a -) { - return a *= k; -} - -// simple multiplication -template -mm::matrix operator*( - const mm::matrix& a, - const mm::matrix& b -) { - // TODO, adjust asserts for transposed cases - static_assert(P1 == P2, "invalid matrix multiplication"); - assert(a.cols() == b.rows()); - - mm::matrix result; - const mm::matrix bt = b.t(); // weak transposition - - //npdebug("Calling *") - - for (unsigned row = 0; row < M; row++) - for (unsigned col = 0; col < N; col++) - result.at(row, col) = a[row] * bt[col]; // scalar product - - return result; -} - -/* - * Matrix operator << - */ - -template -std::ostream& operator<<(std::ostream& os, const mm::matrix& m) { - - for (unsigned index = 0; index < m.rows(); index++) { - os << "[ "; - for (unsigned col = 0; col < m.cols()-1; ++col) { - os << std::setw(NumW) << m.at(index, col) << ", "; + public: + virtual T& at(index row, index col) override { + return m_data[row * Cols + col]; } - os << std::setw(NumW) << m.at(index, m.cols()-1) << " ]\n"; - } - - return os; -} - -/* - * Vector, TODO better manage column and row - */ - -template -class mm::vector : public mm::matrix -{ -public: - - using mm::matrix::matrix; - - vector(std::initializer_list l) - : mm::matrix(l) {} -}; -template -mm::vector operator^(const mm::vector& v, const mm::vector& w) -{ - mm::vector out; - - out[0] = v[1] * w[2] - v[2] * w[2]; - out[1] = v[2] * w[0] - v[0] * w[2]; - out[2] = v[0] * w[1] - v[1] * w[0]; - - return out; -} - -/* - * Square matrix - */ - -template -class mm::square_matrix : public mm::matrix -{ -public: - - using mm::matrix::matrix; - - using diag_iterator = mm::iter::diag_iterator>; - - using const_diag_iterator = mm::iter::diag_iterator::type, N, typename std::add_const>::type>; - - virtual T trace(); - inline T tr() { return trace(); } - - virtual mm::square_matrix::diag_iterator diag_beg(int row = 0) - { - return diag_iterator(*(this->M), row, 0); - } - - virtual mm::square_matrix::const_diag_iterator diag_end(int row = 0) const - { - return const_diag_iterator(*(this->M), row, N); - } - - // TODO, determinant - - /// in place inverse - // TODO, det != 0 - // TODO, use gauss jordan for invertible ones - //void invert();, TODO, section algorithm - - /* - * Generate the identity - */ - - static inline constexpr mm::square_matrix identity() { - mm::square_matrix i; - for (unsigned row = 0; row < N; row++) - for (unsigned col = 0; col < N; col++) - i.at(row, col) = (row == col) ? 1 : 0; - - return i; - } -}; - -template -T mm::square_matrix::trace() -{ - T sum = 0; - for (const auto& x : diag_beg()) - sum += x; + virtual const T& at(index row, index col) const override { + return at(row, col); + } - return sum; -} + private: + std::array m_data; + }; -// TODO, static assert, for all: Diags > -N, Diags < N -// TODO, force Diags to be ordered -template -class mm::multi_diag_matrix : public mm::square_matrix -{ - T& shared_zero = 0; -public: - using mm::square_matrix::square_matrix; + template + struct vector : public matrix {}; - // TODO, ordered case: dichotomy search O(log(M)) - // M = parameter pack size - static inline bool constexpr is_in(std::size_t i, std::size_t j) + template + struct square_matrix : public basic_matrix { - auto t = std::make_tuple(Diags...); + public: + virtual T& at(index row, index col) override { + return m_data[row * N + col]; + } - for(unsigned k(0); k < sizeof...(Diags); ++k) - if ((i - j) == std::get(t)) - return true; + virtual const T& at(index row, index col) const override { + return at(row, col); + } - return false; - } + private: + std::array m_data; + }; - virtual T& at(std::size_t row, std::size_t col) override + template + struct identity_matrix : public basic_matrix { - if (is_in(row, col)) - return mm::square_matrix::at(row, col); + public: + const T& at(index row, index col) const override { + return (row != col) ? static_cast(1) : static_cast(0); + } - shared_zero = 0; - return shared_zero; + private: + T m_useless; + T& at(index row, index col) { return m_useless; } } - virtual const T& at(std::size_t row, std::size_t col) const override + template + struct diagonal_matrix : public basic_matrix { - if (is_in(row, col)) - return mm::square_matrix::at(row, col); - - return 0; - } - - - - // TODO, implement limited iterators -}; + public: + T& at(index row, index col) override { + n_null_element = static_cast(0); + return (row != col) ? m_data[row] : n_null_element; + } -/*template -void constexpr diag_mult(const mm::multi_diag_matrix& a, - const mm::matrix& b, mm::matrix& result) -{ - static_assert(N == P && N == M, "invalid diagonal multiplication"); + const T& at(index row, index col) const override { + return (row != col) ? m_data[row] : static_cast(0); + } - auto d = a.diagonal(Diags); - if constexpr (Diags < 0) { - for (unsigned k = 0; k < M; ++k) - for (unsigned i = -Diags; i < N; ++i) - result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); - } else { - for (unsigned k = 0; k < M; ++k) - for (unsigned i = Diags; i < N; ++i) - result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); + private: + T m_null_element; + std::array m_data; } } - -template -mm::matrix operator*( - const mm::multi_diag_matrix& a, - const mm::matrix& b -) { - static_assert(N == P && N == M, "invalid matrix multiplication"); - assert(a.cols() == b.rows()); - - mm::matrix result; - (( - auto d = a.diagonal(Diags); - if constexpr (Diags < 0) { - for (unsigned k = 0; k < M; ++k) - for (unsigned i = -Diags; i < N; ++i) - result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); - } else { - for (unsigned k = 0; k < M; ++k) - for (unsigned i = Diags; i < N; ++i) - result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); - } - ) ...); - - return result; -}*/ - diff --git a/include/mm/view.hpp b/include/mm/view.hpp new file mode 100644 index 0000000..910c16a --- /dev/null +++ b/include/mm/view.hpp @@ -0,0 +1,66 @@ +#pragma once + +#include + + +namespace mm::alg { + + template < + template typename Matrix, + typename T, std::size_t Rows, std::size_t Cols + > + struct visitor + { + using type = T; + + // copy constructible + visitor(const visitor& other) = default; + + T& operator()(const Matrix& m, index row, index col) { + return m.at(row, col); + } + + const T& operator()(const Matrix& m, index row, index col) { + return operator()(m, row, col); + } + }; + + template < + template typename Matrix, + typename T, std::size_t Rows, std::size_t Cols + > + struct transpose : public visitor + { + T& operator()(const Matrix m, index row, index col) { + // assert(col < Rows) + // assert(row < Cols) + return m.at(col, row); + } + }; +} + +namespace mm { + template < + template typename Matrix, + typename T, std::size_t Rows, std::size_t Cols + > + struct view + { + Matrix& m; + // std::stack> visitors; + std::unique_ptr visitor; + + T& at(index row, index col) { + return visitor(m, row, col); + } + + view& operator|=(const alg::visitor& other) { + // visitors.push(std::move(std::make_unique(other))); + visitor = std::make_unique(other); + } + }; + + view operator|(const view& left, const alg::visitor& right) { + return left |= right; + } +} diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp index a3f0eac..96ba67d 100644 --- a/test/matrix_example.cpp +++ b/test/matrix_example.cpp @@ -32,27 +32,5 @@ int main(int argc, char *argv[]) { std::cout << "a.td() = \n" << a.t(); // or a.trasposed(); std::cout << std::endl; - // square matrix - mm::square_matrix, 2> f {{{2, 3}, {1, 4}}, {{6, 1}, {-3, 4}}}; - - std::cout << "Square matrix" << std::endl; - std::cout << "f = \n" << f; - - std::cout << "tr(f) = " << f.tr(); //or f.trace() << std::endl; - - auto ft = f.t(); - std::cout << "after in place transpose f.t(), f = \n" << ft; - std::cout << std::endl; - - auto identity = mm::square_matrix::identity(); - - std::cout << "Identity matrix" << std::endl; - std::cout << "I = \n" << identity; - std::cout << std::endl; - - // vector - - // - return 0; } -- cgit v1.2.1