From eb5cdb04efc9984d0937b65620606f8043dd1831 Mon Sep 17 00:00:00 2001 From: ancarola Date: Wed, 10 Jul 2019 10:24:37 +0200 Subject: Implicit convertion to basic_vec to vec2 or vec3 --- include/mm/mmiterator.hpp | 6 +- include/mm/mmmatrix.hpp | 198 +++++++++++++++++++++++++++++++++------------- include/mm/mmvec.hpp | 21 +++++ test/matrix_example.cpp | 6 +- 4 files changed, 171 insertions(+), 60 deletions(-) diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp index ed406fe..4efe474 100644 --- a/include/mm/mmiterator.hpp +++ b/include/mm/mmiterator.hpp @@ -28,13 +28,13 @@ public: vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0) : M(_M), position(pos), index(i) {} -#ifdef MM_IMPLICIT_CONVERSION_ITERATOR +//#ifdef MM_IMPLICIT_CONVERSION_ITERATOR operator T&() { - npdebug("Calling +") + //npdebug("Calling +") return *(*this); } -#endif +//#endif IterType operator++() { diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index 0b7a40b..facbb48 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -23,41 +23,30 @@ namespace mm { + /* basic grid structure */ + template class basic_matrix; - /* specialisations */ + /* basic wrapper */ template class matrix; // simple matrix format + + /* specialisations */ - /* specialization of basic_matrx for Cols = 1 */ - template - class row_vec; - - /* specialization of basic_matrx for Rows = 1 */ - template - class col_vec; // transposed version of row_vec + /* specialization of a matrix */ + template + class vector; // by default, set a column vector template class square_matrix; /* specialisation of a square_matrix for a sub-diagonal composed matrix */ - template - class diag_matrix; + template + class multi_diag_matrix; } - -/*namespace mm { - - template - using diag_iterator = vector_iterator>; - - template - using const_diag_iterator = vector_iterator::type, N, N, MM_DIAG_ITER, typename std::add_const>::type>; -}*/ - - /* * Matrix class, no access methods */ @@ -211,11 +200,6 @@ public: other.M = nullptr; } - // copy from another matrix - /*template - matrix(const matrix& other) - : M(std::make_shared(*other.M)), transposed(other.transposed) {} */ - matrix operator=(const basic_matrix& other) // deep copy { *M = *other.M; @@ -232,12 +216,9 @@ public: return *this; } - matrix transpose() const + const matrix transpose() const { - auto m = shallow_cpy(); - m.transposed = !transposed; - - return m; + return matrix(M, !transposed); } inline matrix& td() @@ -277,28 +258,24 @@ public: return static_cast>(*this); } - /// downcast to row_vector - static inline constexpr bool is_row_vec() { return (Cols == 1); } - inline constexpr row_vec to_row_vec() const { - static_assert(is_row_vec()); - return static_cast>(*this); - } - /// downcast to col_vector - static inline constexpr bool is_col_vec() { return (Rows == 1); } - inline constexpr col_vec to_col_vec() const { - static_assert(is_col_vec()); - return static_cast>(*this); + static inline constexpr bool is_vector() { return (Rows == 1 || Cols == 1); } + inline vector to_vector() const { + if constexpr(Cols == 1) + return static_cast>(*this); + else if (Rows == 1) + return vector(*this); // copy into column vector + } /* Accessors */ - T& at(std::size_t row, std::size_t col) + virtual T& at(std::size_t row, std::size_t col) { return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; } - const T& at(std::size_t row, std::size_t col) const + virtual const T& at(std::size_t row, std::size_t col) const { return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; } @@ -311,12 +288,12 @@ public: return (transposed) ? Rows : Cols; } - mm::matrix::vec_iterator operator[](std::size_t index) + virtual mm::matrix::vec_iterator operator[](std::size_t index) { return mm::matrix::vec_iterator(*M, index, 0, !transposed); } - mm::matrix::const_vec_iterator operator[](std::size_t index) const + virtual mm::matrix::const_vec_iterator operator[](std::size_t index) const { return mm::matrix::const_vec_iterator(*M, index, 0, !transposed); } @@ -358,12 +335,7 @@ protected: // shallow construction matrix(std::shared_ptr> grid, bool tr = false) : M(grid), transposed(tr) {} - - matrix shallow_cpy() const - { - return matrix(M, transposed); - } - + private: bool transposed; @@ -416,7 +388,7 @@ mm::matrix operator*( mm::matrix result; const mm::matrix bt = b.t(); // weak transposition - npdebug("Calling *") + //npdebug("Calling *") for (unsigned row = 0; row < M; row++) for (unsigned col = 0; col < N; col++) @@ -443,6 +415,33 @@ std::ostream& operator<<(std::ostream& os, const mm::matrix& m) { return os; } +/* + * Vector, TODO better manage column and row + */ + +template +class mm::vector : public mm::matrix +{ +public: + + using mm::matrix::matrix; + + vector(std::initializer_list l) + : mm::matrix(l) {} +}; + +template +mm::vector operator^(const mm::vector& v, const mm::vector& w) +{ + mm::vector out; + + out[0] = v[1] * w[2] - v[2] * w[2]; + out[1] = v[2] * w[0] - v[0] * w[2]; + out[2] = v[0] * w[1] - v[1] * w[0]; + + return out; +} + /* * Square matrix */ @@ -458,15 +457,15 @@ public: using const_diag_iterator = mm::iter::diag_iterator::type, N, typename std::add_const>::type>; - T trace(); + virtual T trace(); inline T tr() { return trace(); } - mm::square_matrix::diag_iterator diag_beg(int row = 0) + virtual mm::square_matrix::diag_iterator diag_beg(int row = 0) { return diag_iterator(*(this->M), row, 0); } - mm::square_matrix::const_diag_iterator diag_end(int row = 0) const + virtual mm::square_matrix::const_diag_iterator diag_end(int row = 0) const { return const_diag_iterator(*(this->M), row, N); } @@ -502,4 +501,91 @@ T mm::square_matrix::trace() return sum; } +// TODO, static assert, for all: Diags > -N, Diags < N +// TODO, force Diags to be ordered +template +class mm::multi_diag_matrix : public mm::square_matrix +{ + T& shared_zero = 0; + +public: + using mm::square_matrix::square_matrix; + + // TODO, ordered case: dichotomy search O(log(M)) + // M = parameter pack size + static inline bool constexpr is_in(std::size_t i, std::size_t j) + { + auto t = std::make_tuple(Diags...); + + for(unsigned k(0); k < sizeof...(Diags); ++k) + if ((i - j) == std::get(t)) + return true; + + return false; + } + + virtual T& at(std::size_t row, std::size_t col) override + { + if (is_in(row, col)) + return mm::square_matrix::at(row, col); + + shared_zero = 0; + return shared_zero; + } + + virtual const T& at(std::size_t row, std::size_t col) const override + { + if (is_in(row, col)) + return mm::square_matrix::at(row, col); + + return 0; + } + + + + // TODO, implement limited iterators +}; + +/*template +void constexpr diag_mult(const mm::multi_diag_matrix& a, + const mm::matrix& b, mm::matrix& result) +{ + static_assert(N == P && N == M, "invalid diagonal multiplication"); + + auto d = a.diagonal(Diags); + if constexpr (Diags < 0) { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = -Diags; i < N; ++i) + result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); + } else { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = Diags; i < N; ++i) + result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); + } +} + +template +mm::matrix operator*( + const mm::multi_diag_matrix& a, + const mm::matrix& b +) { + static_assert(N == P && N == M, "invalid matrix multiplication"); + assert(a.cols() == b.rows()); + + mm::matrix result; + (( + auto d = a.diagonal(Diags); + if constexpr (Diags < 0) { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = -Diags; i < N; ++i) + result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); + } else { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = Diags; i < N; ++i) + result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); + } + ) ...); + + return result; +}*/ diff --git a/include/mm/mmvec.hpp b/include/mm/mmvec.hpp index 1939388..da9040d 100644 --- a/include/mm/mmvec.hpp +++ b/include/mm/mmvec.hpp @@ -44,6 +44,27 @@ struct mm::basic_vec : public std::array { using type = T; static constexpr std::size_t dimensions = d; + // convertions + static inline constexpr bool is_vec2() { + return d == 2; + } + + static inline constexpr bool is_vec3() { + return d == 3; + } + + operator mm::vec2() + { + static_assert(is_vec2(), "Invalid cast to two dimensional vector"); + return static_cast>(*this); + } + + operator mm::vec3() + { + static_assert(is_vec3(), "Invalid cast to three dimensional vector"); + return static_cast>(*this); + } + // TODO: template away these static constexpr T null_element = static_cast(0); static constexpr T unit_element = static_cast(1); diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp index a835ee0..a3f0eac 100644 --- a/test/matrix_example.cpp +++ b/test/matrix_example.cpp @@ -32,7 +32,7 @@ int main(int argc, char *argv[]) { std::cout << "a.td() = \n" << a.t(); // or a.trasposed(); std::cout << std::endl; - // special matrices + // square matrix mm::square_matrix, 2> f {{{2, 3}, {1, 4}}, {{6, 1}, {-3, 4}}}; std::cout << "Square matrix" << std::endl; @@ -50,5 +50,9 @@ int main(int argc, char *argv[]) { std::cout << "I = \n" << identity; std::cout << std::endl; + // vector + + // + return 0; } -- cgit v1.2.1