diff options
author | ancarola <raffaele.ancarola@epfl.ch> | 2019-07-10 10:24:37 +0200 |
---|---|---|
committer | ancarola <raffaele.ancarola@epfl.ch> | 2019-07-10 10:24:37 +0200 |
commit | eb5cdb04efc9984d0937b65620606f8043dd1831 (patch) | |
tree | 46b750f4f2228e532e2041d873b37caf2940d15c | |
parent | Small correction on basic multiplication (diff) | |
download | libmm-eb5cdb04efc9984d0937b65620606f8043dd1831.tar.gz libmm-eb5cdb04efc9984d0937b65620606f8043dd1831.zip |
Implicit convertion to basic_vec to vec2 or vec3
-rw-r--r-- | include/mm/mmiterator.hpp | 6 | ||||
-rw-r--r-- | include/mm/mmmatrix.hpp | 198 | ||||
-rw-r--r-- | include/mm/mmvec.hpp | 21 | ||||
-rw-r--r-- | test/matrix_example.cpp | 6 |
4 files changed, 171 insertions, 60 deletions
diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp index ed406fe..4efe474 100644 --- a/include/mm/mmiterator.hpp +++ b/include/mm/mmiterator.hpp @@ -28,13 +28,13 @@ public: vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0) : M(_M), position(pos), index(i) {} -#ifdef MM_IMPLICIT_CONVERSION_ITERATOR +//#ifdef MM_IMPLICIT_CONVERSION_ITERATOR operator T&() { - npdebug("Calling +") + //npdebug("Calling +") return *(*this); } -#endif +//#endif IterType operator++() { diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index 0b7a40b..facbb48 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -23,41 +23,30 @@ namespace mm { + /* basic grid structure */ + template<typename T, std::size_t Rows, std::size_t Cols> class basic_matrix; - /* specialisations */ + /* basic wrapper */ template<typename T, std::size_t Rows, std::size_t Cols> class matrix; // simple matrix format + + /* specialisations */ - /* specialization of basic_matrx for Cols = 1 */ - template<typename T, std::size_t Rows> - class row_vec; - - /* specialization of basic_matrx for Rows = 1 */ - template<typename T, std::size_t Cols> - class col_vec; // transposed version of row_vec + /* specialization of a matrix */ + template<typename T, std::size_t N> + class vector; // by default, set a column vector template<typename T, std::size_t N> class square_matrix; /* specialisation of a square_matrix for a sub-diagonal composed matrix */ - template<typename T, std::size_t N, std::size_t K = 0> - class diag_matrix; + template<typename T, std::size_t N, signed long ... Diags> + class multi_diag_matrix; } - -/*namespace mm { - - template<typename T, std::size_t N> - using diag_iterator = vector_iterator<T, N, N, MM_DIAG_ITER, mm::basic_matrix<T, N, N>>; - - template<typename T, std::size_t N> - using const_diag_iterator = vector_iterator<typename std::add_const<T>::type, N, N, MM_DIAG_ITER, typename std::add_const<mm::basic_matrix<T, N, N>>::type>; -}*/ - - /* * Matrix class, no access methods */ @@ -211,11 +200,6 @@ public: other.M = nullptr; } - // copy from another matrix - /*template<std::size_t ORows, std::size_t OCols> - matrix(const matrix<T, ORows, OCols>& other) - : M(std::make_shared<mm::basic_matrix<T, Rows, Cols>(*other.M)), transposed(other.transposed) {} */ - matrix<T, Rows, Cols> operator=(const basic_matrix<T, Rows, Cols>& other) // deep copy { *M = *other.M; @@ -232,12 +216,9 @@ public: return *this; } - matrix<T, Rows, Cols> transpose() const + const matrix<T, Rows, Cols> transpose() const { - auto m = shallow_cpy(); - m.transposed = !transposed; - - return m; + return matrix<T, Rows, Cols>(M, !transposed); } inline matrix<T, Rows, Cols>& td() @@ -277,28 +258,24 @@ public: return static_cast<square_matrix<T, Rows>>(*this); } - /// downcast to row_vector - static inline constexpr bool is_row_vec() { return (Cols == 1); } - inline constexpr row_vec<T, Rows> to_row_vec() const { - static_assert(is_row_vec()); - return static_cast<row_vec<T, Rows>>(*this); - } - /// downcast to col_vector - static inline constexpr bool is_col_vec() { return (Rows == 1); } - inline constexpr col_vec<T, Cols> to_col_vec() const { - static_assert(is_col_vec()); - return static_cast<col_vec<T, Cols>>(*this); + static inline constexpr bool is_vector() { return (Rows == 1 || Cols == 1); } + inline vector<T, Cols> to_vector() const { + if constexpr(Cols == 1) + return static_cast<vector<T, Rows>>(*this); + else if (Rows == 1) + return vector<T, Cols>(*this); // copy into column vector + } /* Accessors */ - T& at(std::size_t row, std::size_t col) + virtual T& at(std::size_t row, std::size_t col) { return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; } - const T& at(std::size_t row, std::size_t col) const + virtual const T& at(std::size_t row, std::size_t col) const { return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col]; } @@ -311,12 +288,12 @@ public: return (transposed) ? Rows : Cols; } - mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index) + virtual mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index) { return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, 0, !transposed); } - mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const + virtual mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const { return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, 0, !transposed); } @@ -358,12 +335,7 @@ protected: // shallow construction matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid, bool tr = false) : M(grid), transposed(tr) {} - - matrix<T, Rows, Cols> shallow_cpy() const - { - return matrix<T, Rows, Cols>(M, transposed); - } - + private: bool transposed; @@ -416,7 +388,7 @@ mm::matrix<T, M, N> operator*( mm::matrix<T, M, N> result; const mm::matrix<T, P2, N> bt = b.t(); // weak transposition - npdebug("Calling *") + //npdebug("Calling *") for (unsigned row = 0; row < M; row++) for (unsigned col = 0; col < N; col++) @@ -444,6 +416,33 @@ std::ostream& operator<<(std::ostream& os, const mm::matrix<T, Rows, Cols>& m) { } /* + * Vector, TODO better manage column and row + */ + +template<typename T, std::size_t N> +class mm::vector : public mm::matrix<T, N, 1> +{ +public: + + using mm::matrix<T, N, 1>::matrix; + + vector(std::initializer_list<T> l) + : mm::matrix<T, N, 1>(l) {} +}; + +template<typename T> +mm::vector<T, 3> operator^(const mm::vector<T, 3>& v, const mm::vector<T, 3>& w) +{ + mm::vector<T, 3> out; + + out[0] = v[1] * w[2] - v[2] * w[2]; + out[1] = v[2] * w[0] - v[0] * w[2]; + out[2] = v[0] * w[1] - v[1] * w[0]; + + return out; +} + +/* * Square matrix */ @@ -458,15 +457,15 @@ public: using const_diag_iterator = mm::iter::diag_iterator<typename std::add_const<T>::type, N, typename std::add_const<mm::basic_matrix<T, N, N>>::type>; - T trace(); + virtual T trace(); inline T tr() { return trace(); } - mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0) + virtual mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0) { return diag_iterator(*(this->M), row, 0); } - mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const + virtual mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const { return const_diag_iterator(*(this->M), row, N); } @@ -502,4 +501,91 @@ T mm::square_matrix<T, N>::trace() return sum; } +// TODO, static assert, for all: Diags > -N, Diags < N +// TODO, force Diags to be ordered +template<typename T, std::size_t N, signed long ... Diags> +class mm::multi_diag_matrix : public mm::square_matrix<T, N> +{ + T& shared_zero = 0; + +public: + using mm::square_matrix<T, N>::square_matrix; + + // TODO, ordered case: dichotomy search O(log(M)) + // M = parameter pack size + static inline bool constexpr is_in(std::size_t i, std::size_t j) + { + auto t = std::make_tuple(Diags...); + + for(unsigned k(0); k < sizeof...(Diags); ++k) + if ((i - j) == std::get<k>(t)) + return true; + + return false; + } + + virtual T& at(std::size_t row, std::size_t col) override + { + if (is_in(row, col)) + return mm::square_matrix<T, N>::at(row, col); + + shared_zero = 0; + return shared_zero; + } + + virtual const T& at(std::size_t row, std::size_t col) const override + { + if (is_in(row, col)) + return mm::square_matrix<T, N>::at(row, col); + + return 0; + } + + + + // TODO, implement limited iterators +}; + +/*template<typename T, std::size_t N, signed long ... Diags, std::size_t P, std::size_t M> +void constexpr diag_mult(const mm::multi_diag_matrix<T, N, Diags...>& a, + const mm::matrix<T, P, M>& b, mm::matrix<T, M, N>& result) +{ + static_assert(N == P && N == M, "invalid diagonal multiplication"); + + auto d = a.diagonal(Diags); + if constexpr (Diags < 0) { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = -Diags; i < N; ++i) + result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); + } else { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = Diags; i < N; ++i) + result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); + } +} + +template<typename T, std::size_t N, signed long ... Diags, std::size_t P, std::size_t M> +mm::matrix<T, M, N> operator*( + const mm::multi_diag_matrix<T, N, Diags...>& a, + const mm::matrix<T, P, M>& b +) { + static_assert(N == P && N == M, "invalid matrix multiplication"); + assert(a.cols() == b.rows()); + + mm::matrix<T, M, N> result; + (( + auto d = a.diagonal(Diags); + if constexpr (Diags < 0) { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = -Diags; i < N; ++i) + result.at(i + Diags, k) += d[i + Diags] * b.at(i, k); + } else { + for (unsigned k = 0; k < M; ++k) + for (unsigned i = Diags; i < N; ++i) + result.at(i, k) += d[i - Diags] * b.at(i - Diags, k); + } + ) ...); + + return result; +}*/ diff --git a/include/mm/mmvec.hpp b/include/mm/mmvec.hpp index 1939388..da9040d 100644 --- a/include/mm/mmvec.hpp +++ b/include/mm/mmvec.hpp @@ -44,6 +44,27 @@ struct mm::basic_vec : public std::array<T, d> { using type = T; static constexpr std::size_t dimensions = d; + // convertions + static inline constexpr bool is_vec2() { + return d == 2; + } + + static inline constexpr bool is_vec3() { + return d == 3; + } + + operator mm::vec2<T>() + { + static_assert(is_vec2(), "Invalid cast to two dimensional vector"); + return static_cast<mm::vec2<T>>(*this); + } + + operator mm::vec3<T>() + { + static_assert(is_vec3(), "Invalid cast to three dimensional vector"); + return static_cast<mm::vec3<T>>(*this); + } + // TODO: template away these static constexpr T null_element = static_cast<T>(0); static constexpr T unit_element = static_cast<T>(1); diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp index a835ee0..a3f0eac 100644 --- a/test/matrix_example.cpp +++ b/test/matrix_example.cpp @@ -32,7 +32,7 @@ int main(int argc, char *argv[]) { std::cout << "a.td() = \n" << a.t(); // or a.trasposed(); std::cout << std::endl; - // special matrices + // square matrix mm::square_matrix<std::complex<int>, 2> f {{{2, 3}, {1, 4}}, {{6, 1}, {-3, 4}}}; std::cout << "Square matrix" << std::endl; @@ -50,5 +50,9 @@ int main(int argc, char *argv[]) { std::cout << "I = \n" << identity; std::cout << std::endl; + // vector + + // + return 0; } |