diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/mm/mmiterator.hpp | 248 | ||||
-rw-r--r-- | include/mm/mmmatrix.hpp | 595 | ||||
-rw-r--r-- | include/mm/view.hpp | 66 |
3 files changed, 132 insertions, 777 deletions
diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp deleted file mode 100644 index 4efe474..0000000 --- a/include/mm/mmiterator.hpp +++ /dev/null @@ -1,248 +0,0 @@ -#pragma once - -#include "debug.hpp" - -namespace mm::iter { - - template<typename T, std::size_t Rows, std::size_t Cols, class IterType, class Grid> - class vector_iterator; - - template<typename T, std::size_t Rows, std::size_t Cols, class Grid> - class basic_iterator; - - template<typename T, std::size_t N, class Grid> - class diag_iterator; -} - -template<typename T, std::size_t Rows, std::size_t Cols, class IterType, class Grid> -class mm::iter::vector_iterator -{ -public: - - template<typename U, std::size_t R, std::size_t C, class G> - friend class mm::iter::basic_iterator; - - template<typename U, std::size_t N, class G> - friend class mm::iter::diag_iterator; - - vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0) - : M(_M), position(pos), index(i) {} - -//#ifdef MM_IMPLICIT_CONVERSION_ITERATOR - operator T&() - { - //npdebug("Calling +") - return *(*this); - } -//#endif - - IterType operator++() - { - IterType it = cpy(); - ++index; - return it; - } - - IterType operator--() - { - IterType it = cpy(); - --index; - return it; - } - - IterType& operator++(int) - { - ++index; - return ref(); - } - - IterType& operator--(int) - { - --index; - return ref(); - } - - bool operator==(const IterType& other) const - { - return index == other.index; - } - - bool operator!=(const IterType& other) const - { - return index != other.index; - } - - bool ok() const - { - return index < size(); - } - - virtual std::size_t size() const = 0; - - virtual T& operator*() = 0; - virtual T& operator[](std::size_t) = 0; - - virtual T& operator[](std::size_t) const = 0; - - IterType begin() - { - return IterType(M, position, 0); - } - - virtual IterType end() = 0; - -protected: - - Grid& M; // grid mapping - - const std::size_t position; // fixed index, negative too for diagonal iterator - std::size_t index; // variable index - - virtual IterType& ref() = 0; - virtual IterType cpy() = 0; -}; - - -/* - * Scalar product - */ - -template<typename T, - std::size_t R1, std::size_t C1, - std::size_t R2, std::size_t C2, - class IterType1, class IterType2, - class Grid1, class Grid2> -typename std::remove_const<T>::type operator*(const mm::iter::vector_iterator<T, R1, C1, IterType1, Grid1>& v, - const mm::iter::vector_iterator<T, R2, C2, IterType2, Grid2>& w) -{ - typename std::remove_const<T>::type out(0); - const std::size_t N = std::min(v.size(), w.size()); - - for(unsigned i = 0; i < N; ++i) - out += v[i] * w[i]; - - return out; -} - -template<typename T, std::size_t Rows, std::size_t Cols, class Grid> -class mm::iter::basic_iterator : public mm::iter::vector_iterator<T, Rows, Cols, mm::iter::basic_iterator<T, Rows, Cols, Grid>, Grid> -{ - bool direction; - - virtual mm::iter::basic_iterator<T, Rows, Cols, Grid>& ref() override - { - return *this; - } - - virtual mm::iter::basic_iterator<T, Rows, Cols, Grid> cpy() override - { - return *this; - } - -public: - - basic_iterator(Grid& A, std::size_t pos, std::size_t _index = 0, bool dir = true) - : mm::iter::vector_iterator<T, Rows, Cols, mm::iter::basic_iterator<T, Rows, Cols, Grid>, Grid> - (A, pos, _index), direction(dir) - { - //npdebug("Position: ", pos, ", Rows: ", Rows, " Cols: ", Cols, ", Direction: ", dir) - - if (direction) - assert(pos < Rows); - else - assert(pos < Cols); - } - - virtual std::size_t size() const - { - return (direction) ? Cols : Rows; - } - - - virtual T& operator*() override - { - return (direction) ? - this->M.data[this->position * Cols + this->index] : - this->M.data[this->index * Cols + this->position]; - - } - - virtual T& operator[](std::size_t i) override - { - return (direction) ? - this->M.data[this->position * Cols + i] : - this->M.data[i * Cols + this->position]; - } - - virtual T& operator[](std::size_t i) const override - { - return (direction) ? - this->M.data[this->position * Cols + i] : - this->M.data[i * Cols + this->position]; - } - - virtual mm::iter::basic_iterator<T, Rows, Cols, Grid> end() - { - return mm::iter::basic_iterator<T, Rows, Cols, Grid>(this->M, this->position, - (direction) ? Cols : Rows); - } -}; - -template<typename T, std::size_t N, class Grid> -class mm::iter::diag_iterator : public mm::iter::vector_iterator<T, N, N, mm::iter::diag_iterator<T, N, Grid>, Grid> -{ - bool sign; - - virtual mm::iter::diag_iterator<T, N, Grid>& ref() override - { - return *this; - } - - virtual mm::iter::diag_iterator<T, N, Grid> cpy() override - { - return *this; - } - -public: - - diag_iterator(Grid& A, signed long int pos, std::size_t _index = 0) - : mm::iter::vector_iterator<T, N, N, mm::iter::diag_iterator<T, N, Grid>, Grid> - (A, static_cast<std::size_t>(labs(pos)), _index), sign(pos >= 0) - { - assert(this->position < N); - } - - virtual std::size_t size() const - { - return N - this->position; - } - - virtual T& operator*() override - { - return (sign) ? - this->M.data[(this->index - this->position) * N + this->index] : - this->M.data[this->index * N + (this->index + this->position)]; - } - - virtual T& operator[](std::size_t i) override - { - return (sign) ? - this->M.data[(i - this->position) * N + i] : - this->M.data[i * N + (i + this->position)]; - } - - virtual T& operator[](std::size_t i) const override - { - return (sign) ? - this->M.data[(i - this->position) * N + i] : - this->M.data[i * N + (i + this->position)]; - } - - - virtual mm::iter::diag_iterator<T, N, Grid> end() - { - return mm::iter::diag_iterator<T, N, Grid>(this->M, this->position, N); - } -}; - - diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index facbb48..d653d72 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -12,580 +12,117 @@ #pragma once #include <iostream> -#include <iomanip> -#include <cstring> #include <cassert> #include <initializer_list> #include <array> #include <memory> -#include "mm/mmiterator.hpp" namespace mm { - - /* basic grid structure */ + using index = std::size_t; template<typename T, std::size_t Rows, std::size_t Cols> class basic_matrix; - /* basic wrapper */ - - template<typename T, std::size_t Rows, std::size_t Cols> - class matrix; // simple matrix format - /* specialisations */ - - /* specialization of a matrix */ - template<typename T, std::size_t N> - class vector; // by default, set a column vector + template<typename T, std::size_t Rows, std::size_t Cols> + class matrix; template<typename T, std::size_t N> class square_matrix; - /* specialisation of a square_matrix for a sub-diagonal composed matrix */ - template<typename T, std::size_t N, signed long ... Diags> - class multi_diag_matrix; + template<typename T, std::size_t N> + class diagonal_matrix; } /* * Matrix class, no access methods */ - -template<typename T, std::size_t Rows, std::size_t Cols> -class mm::basic_matrix -{ -public: - using type = T; - - template<typename U, std::size_t ORows, std::size_t OCols> - friend class mm::basic_matrix; - - template<typename U, std::size_t ORows, std::size_t OCols> - friend class mm::matrix; - - template<typename U, std::size_t ORows, std::size_t OCols, class Grid> - friend class mm::iter::basic_iterator; - - template<typename U, std::size_t ON, class Grid> - friend class mm::iter::diag_iterator; - - //template<typename U, std::size_t ORows, std::size_t OCols, class Grid> - //friend class mm::iter::basic_iterator<T, Rows, Cols, mm::basic_matrix<T, Rows, Cols>>; - - //template<typename U, std::size_t ORows, std::size_t OCols, class Grid> - //friend class mm::iter::basic_iterator<typename std::add_const<T>::type, Rows, Cols, typename std::add_const<mm::basic_matrix<T, Rows, Cols>>::type>; - - basic_matrix(); - - // from initializer_list - basic_matrix(std::initializer_list<std::initializer_list<T>> l); - - // copyable and movable - basic_matrix(const basic_matrix<T, Rows, Cols>& other) = default; - basic_matrix(basic_matrix<T, Rows, Cols>&& other) = default; - - // copy from another matrix - template<std::size_t ORows, std::size_t OCols> - basic_matrix(const basic_matrix<T, ORows, OCols>& other); - - void swap_rows(std::size_t x, std::size_t y); - void swap_cols(std::size_t x, std::size_t y); - - // mathematical operations - //virtual basic_matrix<T, Cols, Rows> transposed() const; - //inline basic_matrix<T, Cols, Rows> td() const { return transposed(); } - -protected: - template<typename ConstIterator> - basic_matrix(ConstIterator begin, ConstIterator end); - -private: - std::array<T, Rows * Cols> data; -}; - - -template<typename T, std::size_t Rows, std::size_t Cols> -mm::basic_matrix<T, Rows, Cols>::basic_matrix() { - std::fill(data.begin(), data.end(), 0); -} - -template<typename T, std::size_t Rows, std::size_t Cols> -mm::basic_matrix<T, Rows, Cols>::basic_matrix( - std::initializer_list<std::initializer_list<T>> l -) { - assert(l.size() == Rows); - auto data_it = data.begin(); - - for (auto&& row : l) { - data_it = std::copy(row.begin(), row.end(), data_it); - } -} - -template<typename T, std::size_t Rows, std::size_t Cols> -template<std::size_t ORows, std::size_t OCols> -mm::basic_matrix<T, Rows, Cols>::basic_matrix( - const mm::basic_matrix<T, ORows, OCols>& other -) { - static_assert((ORows <= Rows), - "cannot copy a taller matrix into a smaller one" - ); - - static_assert((OCols <= Cols), - "cannot copy a larger matrix into a smaller one" - ); - - std::fill(data.begin(), data.end(), 0); - for (unsigned row = 0; row < Rows; row++) - for (unsigned col = 0; col < Cols; col++) - this->at(row, col) = other.at(row, col); -} - -/* protected construtor */ -template<typename T, std::size_t Rows, std::size_t Cols> -template<typename ConstIterator> -mm::basic_matrix<T, Rows, Cols>::basic_matrix( - ConstIterator begin, ConstIterator end -) { - assert(static_cast<unsigned>(std::distance(begin, end)) >= ((Rows * Cols))); - std::copy(begin, end, data.begin()); -} - -template<typename T, std::size_t Rows, std::size_t Cols> -void mm::basic_matrix<T, Rows, Cols>::swap_rows(std::size_t x, std::size_t y) { - if (x == y) - return; - - for (unsigned col = 0; col < Cols; col++) - std::swap(this->at(x, col), this->at(y, col)); -} - -template<typename T, std::size_t Rows, std::size_t Cols> -void mm::basic_matrix<T, Rows, Cols>::swap_cols(std::size_t x, std::size_t y) { - if (x == y) - return; - - for (unsigned row = 0; row < Rows; row++) - std::swap(this->at(row, x), this->at(row, y)); -} - -/* - * Matrix object - */ - -template<typename T, std::size_t Rows, std::size_t Cols> -class mm::matrix -{ -public: - - //template<typename U, std::size_t ORows, std::size_t OCols> - using vec_iterator = mm::iter::basic_iterator<T, Rows, Cols, mm::basic_matrix<T, Rows, Cols>>; - - //template<typename U, std::size_t ORows, std::size_t OCols> - using const_vec_iterator = mm::iter::basic_iterator<typename std::add_const<T>::type, Rows, Cols, typename std::add_const<mm::basic_matrix<T, Rows, Cols>>::type>; - - // default zeros constructor - matrix() : M(std::make_shared<mm::basic_matrix<T, Rows, Cols>>()), transposed(false) {} - - // from initializer_list - matrix(std::initializer_list<std::initializer_list<T>> l) - : M(std::make_shared<mm::basic_matrix<T, Rows, Cols>>(l)), transposed(false) {} - - // copyable and movable - matrix(const matrix<T, Rows, Cols>& other) // deep copy - : M(std::make_shared<mm::basic_matrix<T, Rows, Cols>>(*other.M)), transposed(other.transposed) {} - - matrix(basic_matrix<T, Rows, Cols>&& other) // move ptr - : M(other.M), transposed(other.transposed) - { - other.M = nullptr; - } - - matrix<T, Rows, Cols> operator=(const basic_matrix<T, Rows, Cols>& other) // deep copy - { - *M = *other.M; - transposed = other.transposed; - } - - /* - * Transposition - */ - - matrix<T, Rows, Cols>& transpose_d() - { - transposed = !transposed; - return *this; - } - - const matrix<T, Rows, Cols> transpose() const - { - return matrix<T, Rows, Cols>(M, !transposed); - } - - inline matrix<T, Rows, Cols>& td() - { - return transpose(); - } - - inline matrix<T, Rows, Cols> t() const - { - return transpose(); - } - - // strongly transpose - matrix<T, Cols, Rows> transpose_cpy() const - { - matrix<T, Cols, Rows> out(); // copy - // TODO - } - - /* - * Pointer status - */ - - bool expired() const +namespace mm { + template<typename T, std::size_t Rows, std::size_t Cols> + class basic_matrix { - return M == nullptr; - } - - /* - * Downcasting conditions - */ - - /// downcast to square matrix - static inline constexpr bool is_square() { return (Rows == Cols); } - inline constexpr square_matrix<T, Rows> to_square() const { - static_assert(is_square()); - return static_cast<square_matrix<T, Rows>>(*this); - } - - /// downcast to col_vector - static inline constexpr bool is_vector() { return (Rows == 1 || Cols == 1); } - inline vector<T, Cols> to_vector() const { - if constexpr(Cols == 1) - return static_cast<vector<T, Rows>>(*this); - else if (Rows == 1) - return vector<T, Cols>(*this); // copy into column vector + public: + using type = T; - } + template<typename U, std::size_t ORows, std::size_t OCols> + friend class mm::matrix; - /* Accessors */ + // copy from another matrix + template<std::size_t ORows, std::size_t OCols> + matrix(const basic_matrix<T, ORows, OCols>& other); - virtual T& at(std::size_t row, std::size_t col) - { - return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; - } + virtual T& at(index row, index col) = 0; + virtual const T& at(index row, index col) const = 0; + }; - virtual const T& at(std::size_t row, std::size_t col) const - { - return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; - } - std::size_t rows() const { - return (transposed) ? Cols : Rows; - } - - std::size_t cols() const { - return (transposed) ? Rows : Cols; - } - - virtual mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index) - { - return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, 0, !transposed); - } + /* Specializations */ - virtual mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const + template<typename T, std::size_t Rows, std::size_t Cols> + struct matrix : public basic_matrix<T, N> { - return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, 0, !transposed); - } - - /* - * Basic matematical operations (dimension indipendent) - */ - - mm::matrix<T, Rows, Cols>& operator+=(const mm::matrix<T, Rows, Cols>& m) { - - for (unsigned row = 0; row < std::min(rows(), m.rows()); ++row) - for (unsigned col = 0; col < std::min(cols(), m.cols()); ++col) - at(row, col) += m.at(row, col); - - return *this; - } - - mm::matrix<T, Rows, Cols>& operator-=(const mm::matrix<T, Rows, Cols>& m) { - - for (unsigned row = 0; row < std::min(rows(), m.rows()); ++row) - for (unsigned col = 0; col < std::min(cols(), m.cols()); ++col) - at(row, col) -= m.at(row, col); - - return *this; - } - - mm::matrix<T, Rows, Cols> operator*=(const T& k) { - - for (unsigned row = 0; row < rows(); ++row) - for (auto& x : (*this)[row]) - x *= k; - - return *this; - } - -protected: - - std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> M; - - // shallow construction - matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid, bool tr = false) : M(grid), transposed(tr) {} - -private: - - bool transposed; -}; - -/* Basic operator overloading (dimension indipendent) */ - -template<typename T, std::size_t Rows, std::size_t Cols> -mm::matrix<T, Rows, Cols> operator+( - mm::matrix<T, Rows, Cols> a, - const mm::matrix<T, Rows, Cols>& b -) { - return a += b; -} - -template<typename T, std::size_t Rows, std::size_t Cols> -mm::matrix<T, Rows, Cols> operator-( - mm::matrix<T, Rows, Cols> a, - const mm::matrix<T, Rows, Cols>& b -) { - return a -= b; -} - -template<typename T, std::size_t Rows, std::size_t Cols> -mm::matrix<T, Rows, Cols> operator*( - mm::matrix<T, Rows, Cols> a, - const T& k -) { - return a *= k; -} - -template<typename T, std::size_t Rows, std::size_t Cols> -mm::matrix<T, Rows, Cols> operator*( - const T& k, - mm::matrix<T, Rows, Cols> a -) { - return a *= k; -} - -// simple multiplication -template<typename T, std::size_t M, std::size_t P1, std::size_t P2, std::size_t N> -mm::matrix<T, M, N> operator*( - const mm::matrix<T, M, P1>& a, - const mm::matrix<T, P2, N>& b -) { - // TODO, adjust asserts for transposed cases - static_assert(P1 == P2, "invalid matrix multiplication"); - assert(a.cols() == b.rows()); - - mm::matrix<T, M, N> result; - const mm::matrix<T, P2, N> bt = b.t(); // weak transposition - - //npdebug("Calling *") - - for (unsigned row = 0; row < M; row++) - for (unsigned col = 0; col < N; col++) - result.at(row, col) = a[row] * bt[col]; // scalar product - - return result; -} - -/* - * Matrix operator << - */ - -template<typename T, std::size_t Rows, std::size_t Cols, unsigned NumW = 3> -std::ostream& operator<<(std::ostream& os, const mm::matrix<T, Rows, Cols>& m) { - - for (unsigned index = 0; index < m.rows(); index++) { - os << "[ "; - for (unsigned col = 0; col < m.cols()-1; ++col) { - os << std::setw(NumW) << m.at(index, col) << ", "; + public: + virtual T& at(index row, index col) override { + return m_data[row * Cols + col]; } - os << std::setw(NumW) << m.at(index, m.cols()-1) << " ]\n"; - } - - return os; -} - -/* - * Vector, TODO better manage column and row - */ - -template<typename T, std::size_t N> -class mm::vector : public mm::matrix<T, N, 1> -{ -public: - - using mm::matrix<T, N, 1>::matrix; - - vector(std::initializer_list<T> l) - : mm::matrix<T, N, 1>(l) {} -}; -template<typename T> -mm::vector<T, 3> operator^(const mm::vector<T, 3>& v, const mm::vector<T, 3>& w) -{ - mm::vector<T, 3> out; - - out[0] = v[1] * w[2] - v[2] * w[2]; - out[1] = v[2] * w[0] - v[0] * w[2]; - out[2] = v[0] * w[1] - v[1] * w[0]; - - return out; -} - -/* - * Square matrix - */ - -template<typename T, std::size_t N> -class mm::square_matrix : public mm::matrix<T, N, N> -{ -public: - - using mm::matrix<T, N, N>::matrix; - - using diag_iterator = mm::iter::diag_iterator<T, N, mm::basic_matrix<T, N, N>>; - - using const_diag_iterator = mm::iter::diag_iterator<typename std::add_const<T>::type, N, typename std::add_const<mm::basic_matrix<T, N, N>>::type>; - - virtual T trace(); - inline T tr() { return trace(); } - - virtual mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0) - { - return diag_iterator(*(this->M), row, 0); - } - - virtual mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const - { - return const_diag_iterator(*(this->M), row, N); - } - - // TODO, determinant - - /// in place inverse - // TODO, det != 0 - // TODO, use gauss jordan for invertible ones - //void invert();, TODO, section algorithm - - /* - * Generate the identity - */ - - static inline constexpr mm::square_matrix<T, N> identity() { - mm::square_matrix<T, N> i; - for (unsigned row = 0; row < N; row++) - for (unsigned col = 0; col < N; col++) - i.at(row, col) = (row == col) ? 1 : 0; - - return i; - } -}; - -template<typename T, std::size_t N> -T mm::square_matrix<T, N>::trace() -{ - T sum = 0; - for (const auto& x : diag_beg()) - sum += x; + virtual const T& at(index row, index col) const override { + return at(row, col); + } - return sum; -} + private: + std::array<T, Rows * Cols> m_data; + }; -// TODO, static assert, for all: Diags > -N, Diags < N -// TODO, force Diags to be ordered -template<typename T, std::size_t N, signed long ... Diags> -class mm::multi_diag_matrix : public mm::square_matrix<T, N> -{ - T& shared_zero = 0; -public: - using mm::square_matrix<T, N>::square_matrix; + template<typename T, std::size_t N> + struct vector : public matrix<T, 1, N> {}; - // TODO, ordered case: dichotomy search O(log(M)) - // M = parameter pack size - static inline bool constexpr is_in(std::size_t i, std::size_t j) + template<typename T, std::size_t N> + struct square_matrix : public basic_matrix<T, N> { - auto t = std::make_tuple(Diags...); + public: + virtual T& at(index row, index col) override { + return m_data[row * N + col]; + } - for(unsigned k(0); k < sizeof...(Diags); ++k) - if ((i - j) == std::get<k>(t)) - return true; + virtual const T& at(index row, index col) const override { + return at(row, col); + } - return false; - } + private: + std::array<T, N*N> m_data; + }; - virtual T& at(std::size_t row, std::size_t col) override + template<typename T, std::size_t N> + struct identity_matrix : public basic_matrix<T, N, N> { - if (is_in(row, col)) - return mm::square_matrix<T, N>::at(row, col); + public: + const T& at(index row, index col) const override { + return (row != col) ? static_cast<T>(1) : static_cast<T>(0); + } - shared_zero = 0; - return shared_zero; + private: + T m_useless; + T& at(index row, index col) { return m_useless; } } - virtual const T& at(std::size_t row, std::size_t col) const override + template<typename T, std::size_t N> + struct diagonal_matrix : public basic_matrix<T, N, N> { - if (is_in(row, col)) - return mm::square_matrix<T, N>::at(row, col); - - return 0; - } - - - - // TODO, implement limited iterators -}; + public: + T& at(index row, index col) override { + n_null_element = static_cast<T>(0); + return (row != col) ? m_data[row] : n_null_element; + } -/*template<typename T, std::size_t N, signed long ... Diags, std::size_t P, std::size_t M> -void constexpr diag_mult(const mm::multi_diag_matrix<T, N, Diags...>& a, - const mm::matrix<T, P, M>& b, mm::matrix<T, M, N>& result) -{ - static_assert(N == P && N == M, "invalid diagonal multiplication"); + const T& at(index row, index col) const override { + return (row != col) ? m_data[row] : static_cast<T>(0); + } - auto d = a.diagonal(Diags); - if constexpr (Diags < 0) { - for (unsigned k = 0; k < M; ++k) - for (unsigned i = -Diags; i < N; ++i) - result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); - } else { - for (unsigned k = 0; k < M; ++k) - for (unsigned i = Diags; i < N; ++i) - result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); + private: + T m_null_element; + std::array<T, N> m_data; } } - -template<typename T, std::size_t N, signed long ... Diags, std::size_t P, std::size_t M> -mm::matrix<T, M, N> operator*( - const mm::multi_diag_matrix<T, N, Diags...>& a, - const mm::matrix<T, P, M>& b -) { - static_assert(N == P && N == M, "invalid matrix multiplication"); - assert(a.cols() == b.rows()); - - mm::matrix<T, M, N> result; - (( - auto d = a.diagonal(Diags); - if constexpr (Diags < 0) { - for (unsigned k = 0; k < M; ++k) - for (unsigned i = -Diags; i < N; ++i) - result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); - } else { - for (unsigned k = 0; k < M; ++k) - for (unsigned i = Diags; i < N; ++i) - result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); - } - ) ...); - - return result; -}*/ - diff --git a/include/mm/view.hpp b/include/mm/view.hpp new file mode 100644 index 0000000..910c16a --- /dev/null +++ b/include/mm/view.hpp @@ -0,0 +1,66 @@ +#pragma once + +#include <mmmatrix.hpp> + + +namespace mm::alg { + + template < + template<typename, std::size_t, std::size_t> typename Matrix, + typename T, std::size_t Rows, std::size_t Cols + > + struct visitor + { + using type = T; + + // copy constructible + visitor(const visitor& other) = default; + + T& operator()(const Matrix<T, Rows, Cols>& m, index row, index col) { + return m.at(row, col); + } + + const T& operator()(const Matrix<T, Rows, Cols>& m, index row, index col) { + return operator()(m, row, col); + } + }; + + template < + template<typename, std::size_t, std::size_t> typename Matrix, + typename T, std::size_t Rows, std::size_t Cols + > + struct transpose : public visitor<Matrix, T, Rows, Cols> + { + T& operator()(const Matrix<T, Rows, Cols> m, index row, index col) { + // assert(col < Rows) + // assert(row < Cols) + return m.at(col, row); + } + }; +} + +namespace mm { + template < + template<typename, std::size_t, std::size_t> typename Matrix, + typename T, std::size_t Rows, std::size_t Cols + > + struct view + { + Matrix<T, Rows, Cols>& m; + // std::stack<std::unique_ptr<alg::visitor>> visitors; + std::unique_ptr<alg::visitor> visitor; + + T& at(index row, index col) { + return visitor(m, row, col); + } + + view& operator|=(const alg::visitor& other) { + // visitors.push(std::move(std::make_unique<alg::visitor>(other))); + visitor = std::make_unique<alg::visistor>(other); + } + }; + + view operator|(const view& left, const alg::visitor& right) { + return left |= right; + } +} |