diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/mm/debug.hpp | 78 | ||||
-rw-r--r-- | include/mm/mmiterator.hpp | 39 | ||||
-rw-r--r-- | include/mm/mmmatrix.hpp | 111 |
3 files changed, 184 insertions, 44 deletions
diff --git a/include/mm/debug.hpp b/include/mm/debug.hpp new file mode 100644 index 0000000..0254744 --- /dev/null +++ b/include/mm/debug.hpp @@ -0,0 +1,78 @@ +#pragma once + +#ifndef __NPDEBUG__ +#define __NPDEBUG__ + +#include <iostream> +#include <sstream> + +#ifndef NDEBUG + #define __FILENAME__ (\ + __builtin_strrchr(__FILE__, '/') ? \ + __builtin_strrchr(__FILE__, '/') + 1 : __FILE__) + + #define npdebug_prep(); { \ + std::cerr << "[" << __FILENAME__ \ + << ":" << __LINE__ \ + << ", " << __func__ \ + << "] " ; \ + } + + #define npdebug(...); { \ + npdebug_prep(); \ + np::va_debug(__VA_ARGS__); \ + } + + namespace np { + template<typename... Args> + inline void va_debug(Args&&... args) { + (std::cerr << ... << args) << std::endl; + } + + template<typename T> + void range_debug(const T& t) { + range_debug("", t); + } + + template<typename T> + void range_debug(const std::string& msg, const T& t) { + std::string out; + for (auto elem : t) + out += elem += ", "; + + npdebug(msg, out); + } + + template<typename T> + T inspect(const T& t) { + npdebug(t); + return t; + } + + template<typename T> + T inspect(const std::string& msg, const T& t) { + npdebug(msg, t); + return t; + } + } +#else + #define npdebug(...) {} + + namespace np { + template<typename... Args> + inline void va_debug(Args&... args) {} + + template<typename T> + inline void range_debug(const T& t) {} + + template<typename T> + inline void range_debug(const std::string& msg, const T& t) {} + + template<typename T> + T inspect(const T& t) { return t; } + + template<typename T> + T inspect(const std::string& msg, const T& t) { return t; } + } +#endif // NDEBUG +#endif // __NPDEBUG__ diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp index c67b92b..7b3480e 100644 --- a/include/mm/mmiterator.hpp +++ b/include/mm/mmiterator.hpp @@ -1,5 +1,7 @@ #pragma once +#include "debug.hpp" + namespace mm::iter { template<typename T, std::size_t Rows, std::size_t Cols, class IterType, class Grid> @@ -33,14 +35,14 @@ public: IterType operator++() { - IterType it = *this; + IterType it = cpy(); ++index; return it; } IterType operator--() { - IterType it = *this; + IterType it = cpy(); --index; return it; } @@ -48,13 +50,13 @@ public: IterType& operator++(int) { ++index; - return *this; + return ref(); } IterType& operator--(int) { --index; - return *this; + return ref(); } bool operator==(const IterType& other) const @@ -111,6 +113,9 @@ protected: const std::size_t position; // fixed index, negative too for diagonal iterator std::size_t index; // variable index + + virtual IterType& ref() = 0; + virtual IterType cpy() = 0; }; template<typename T, std::size_t Rows, std::size_t Cols, class Grid> @@ -118,12 +123,24 @@ class mm::iter::basic_iterator : public mm::iter::vector_iterator<T, Rows, Cols, { bool direction; + virtual mm::iter::basic_iterator<T, Rows, Cols, Grid>& ref() override + { + return *this; + } + + virtual mm::iter::basic_iterator<T, Rows, Cols, Grid> cpy() override + { + return *this; + } + public: basic_iterator(Grid& A, std::size_t pos, std::size_t _index = 0, bool dir = true) : mm::iter::vector_iterator<T, Rows, Cols, mm::iter::basic_iterator<T, Rows, Cols, Grid>, Grid> (A, pos, _index), direction(dir) { + //npdebug("Position: ", pos, ", Rows: ", Rows, " Cols: ", Cols, ", Direction: ", dir) + if (direction) assert(pos < Rows); else @@ -162,11 +179,21 @@ class mm::iter::diag_iterator : public mm::iter::vector_iterator<T, N, N, mm::it { bool sign; + virtual mm::iter::diag_iterator<T, N, Grid>& ref() override + { + return *this; + } + + virtual mm::iter::diag_iterator<T, N, Grid> cpy() override + { + return *this; + } + public: - diag_iterator(Grid& A, signed long pos, std::size_t _index = 0) + diag_iterator(Grid& A, signed long int pos, std::size_t _index = 0) : mm::iter::vector_iterator<T, N, N, mm::iter::diag_iterator<T, N, Grid>, Grid> - (A, static_cast<std::size_t>(abs(pos)), _index), sign(pos >= 0) + (A, static_cast<std::size_t>(labs(pos)), _index), sign(pos >= 0) { assert(this->position < N); } diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index f88e90c..5eb0171 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -77,6 +77,9 @@ public: template<typename U, std::size_t ORows, std::size_t OCols, class Grid> friend class mm::iter::basic_iterator; + template<typename U, std::size_t ON, class Grid> + friend class mm::iter::diag_iterator; + //template<typename U, std::size_t ORows, std::size_t OCols, class Grid> //friend class mm::iter::basic_iterator<T, Rows, Cols, mm::basic_matrix<T, Rows, Cols>>; @@ -183,11 +186,6 @@ void mm::basic_matrix<T, Rows, Cols>::swap_cols(std::size_t x, std::size_t y) { template<typename T, std::size_t Rows, std::size_t Cols> class mm::matrix { -protected: - - // shallow construction - matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid = nullptr, bool tr = false) : M(grid), transposed(tr) {} - public: //template<typename U, std::size_t ORows, std::size_t OCols> @@ -228,7 +226,7 @@ public: * Transposition */ - matrix<T, Rows, Cols>& transpose() + matrix<T, Rows, Cols>& transpose_d() { transposed = !transposed; return *this; @@ -242,7 +240,7 @@ public: return m; } - inline matrix<T, Rows, Cols>& t() + inline matrix<T, Rows, Cols>& td() { return transpose(); } @@ -315,12 +313,12 @@ public: mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index) { - return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, !transposed); + return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, 0, !transposed); } mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const { - return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, !transposed); + return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, 0, !transposed); } /* @@ -358,7 +356,10 @@ protected: std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> M; - matrix<T, Rows, Cols> shallow_cpy() + // shallow construction + matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid, bool tr = false) : M(grid), transposed(tr) {} + + matrix<T, Rows, Cols> shallow_cpy() const { return matrix<T, Rows, Cols>(M, transposed); } @@ -408,6 +409,7 @@ mm::matrix<T, M, N> operator*( const mm::matrix<T, M, P1>& a, const mm::matrix<T, P2, N>& b ) { + // TODO, adjust asserts for transposed cases static_assert(P1 == P2, "invalid matrix multiplication"); assert(a.cols() == b.rows()); @@ -421,34 +423,6 @@ mm::matrix<T, M, N> operator*( return result; } -// transposed multiplication -/*template<typename T, std::size_t M, std::size_t P1, std::size_t P2, std::size_t N> -mm::matrix<T, M, N> operator*( - const mm::matrix<T, P1, M>& a, - const mm::matrix<T, P2, N>& b -) { - static_assert(P1 == P2, "invalid matrix multiplication"); - assert(a.cols() == b.rows()); - - mm::matrix<T, M, N> result; - mm::matrix<T, P2, N> bt = b.t(); // weak transposition - - for (unsigned row = 0; row < M; row++) - for (unsigned col = 0; col < N; col++) - result.at(row, col) = a[row] * bt[col]; // scalar product - - return result; -}*/ - - -/*template<typename T, std::size_t N> -void mm::square_matrix<T, N>::transpose() { - for (unsigned row = 0; row < N; row++) - for (unsigned col = 0; col < row; col++) - std::swap(this->at(row, col), this->at(col, row)); -}*/ - - /* * Matrix operator << */ @@ -466,3 +440,64 @@ std::ostream& operator<<(std::ostream& os, const mm::matrix<T, Rows, Cols>& m) { return os; } + +/* + * Square matrix + */ + +template<typename T, std::size_t N> +class mm::square_matrix : public mm::matrix<T, N, N> +{ +public: + + using mm::matrix<T, N, N>::matrix; + + using diag_iterator = mm::iter::diag_iterator<T, N, mm::basic_matrix<T, N, N>>; + + using const_diag_iterator = mm::iter::diag_iterator<typename std::add_const<T>::type, N, typename std::add_const<mm::basic_matrix<T, N, N>>::type>; + + T trace(); + inline T tr() { return trace(); } + + mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0) + { + return diag_iterator(*(this->M), row, 0); + } + + mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const + { + return const_diag_iterator(*(this->M), row, N); + } + + // TODO, determinant + + /// in place inverse + // TODO, det != 0 + // TODO, use gauss jordan for invertible ones + //void invert();, TODO, section algorithm + + /* + * Generate the identity + */ + + static inline constexpr mm::square_matrix<T, N> identity() { + mm::square_matrix<T, N> i; + for (unsigned row = 0; row < N; row++) + for (unsigned col = 0; col < N; col++) + i.at(row, col) = (row == col) ? 1 : 0; + + return i; + } +}; + +template<typename T, std::size_t N> +T mm::square_matrix<T, N>::trace() +{ + T sum = 0; + for (const auto& x : diag_beg()) + sum += x; + + return sum; +} + + |