From 289b76e4a2177e320fe906fdbbbe3b20a2f11942 Mon Sep 17 00:00:00 2001 From: ancarola Date: Mon, 1 Jul 2019 00:00:56 +0200 Subject: The matrix library is compiling and all tested operations work fine. Next goals: - Implement optimisations for multiplication in K-diagonal - Add adjoint operation for complex matrices - Determinant - Algorithms: Gauss Jordan --- include/mm/debug.hpp | 78 ++++++++++++++++++++++++++++++++ include/mm/mmiterator.hpp | 39 +++++++++++++--- include/mm/mmmatrix.hpp | 111 ++++++++++++++++++++++++++++++---------------- test/matrix_example.cpp | 14 +++--- 4 files changed, 191 insertions(+), 51 deletions(-) create mode 100644 include/mm/debug.hpp diff --git a/include/mm/debug.hpp b/include/mm/debug.hpp new file mode 100644 index 0000000..0254744 --- /dev/null +++ b/include/mm/debug.hpp @@ -0,0 +1,78 @@ +#pragma once + +#ifndef __NPDEBUG__ +#define __NPDEBUG__ + +#include +#include + +#ifndef NDEBUG + #define __FILENAME__ (\ + __builtin_strrchr(__FILE__, '/') ? \ + __builtin_strrchr(__FILE__, '/') + 1 : __FILE__) + + #define npdebug_prep(); { \ + std::cerr << "[" << __FILENAME__ \ + << ":" << __LINE__ \ + << ", " << __func__ \ + << "] " ; \ + } + + #define npdebug(...); { \ + npdebug_prep(); \ + np::va_debug(__VA_ARGS__); \ + } + + namespace np { + template + inline void va_debug(Args&&... args) { + (std::cerr << ... << args) << std::endl; + } + + template + void range_debug(const T& t) { + range_debug("", t); + } + + template + void range_debug(const std::string& msg, const T& t) { + std::string out; + for (auto elem : t) + out += elem += ", "; + + npdebug(msg, out); + } + + template + T inspect(const T& t) { + npdebug(t); + return t; + } + + template + T inspect(const std::string& msg, const T& t) { + npdebug(msg, t); + return t; + } + } +#else + #define npdebug(...) {} + + namespace np { + template + inline void va_debug(Args&... args) {} + + template + inline void range_debug(const T& t) {} + + template + inline void range_debug(const std::string& msg, const T& t) {} + + template + T inspect(const T& t) { return t; } + + template + T inspect(const std::string& msg, const T& t) { return t; } + } +#endif // NDEBUG +#endif // __NPDEBUG__ diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp index c67b92b..7b3480e 100644 --- a/include/mm/mmiterator.hpp +++ b/include/mm/mmiterator.hpp @@ -1,5 +1,7 @@ #pragma once +#include "debug.hpp" + namespace mm::iter { template @@ -33,14 +35,14 @@ public: IterType operator++() { - IterType it = *this; + IterType it = cpy(); ++index; return it; } IterType operator--() { - IterType it = *this; + IterType it = cpy(); --index; return it; } @@ -48,13 +50,13 @@ public: IterType& operator++(int) { ++index; - return *this; + return ref(); } IterType& operator--(int) { --index; - return *this; + return ref(); } bool operator==(const IterType& other) const @@ -111,6 +113,9 @@ protected: const std::size_t position; // fixed index, negative too for diagonal iterator std::size_t index; // variable index + + virtual IterType& ref() = 0; + virtual IterType cpy() = 0; }; template @@ -118,12 +123,24 @@ class mm::iter::basic_iterator : public mm::iter::vector_iterator& ref() override + { + return *this; + } + + virtual mm::iter::basic_iterator cpy() override + { + return *this; + } + public: basic_iterator(Grid& A, std::size_t pos, std::size_t _index = 0, bool dir = true) : mm::iter::vector_iterator, Grid> (A, pos, _index), direction(dir) { + //npdebug("Position: ", pos, ", Rows: ", Rows, " Cols: ", Cols, ", Direction: ", dir) + if (direction) assert(pos < Rows); else @@ -162,11 +179,21 @@ class mm::iter::diag_iterator : public mm::iter::vector_iterator& ref() override + { + return *this; + } + + virtual mm::iter::diag_iterator cpy() override + { + return *this; + } + public: - diag_iterator(Grid& A, signed long pos, std::size_t _index = 0) + diag_iterator(Grid& A, signed long int pos, std::size_t _index = 0) : mm::iter::vector_iterator, Grid> - (A, static_cast(abs(pos)), _index), sign(pos >= 0) + (A, static_cast(labs(pos)), _index), sign(pos >= 0) { assert(this->position < N); } diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp index f88e90c..5eb0171 100644 --- a/include/mm/mmmatrix.hpp +++ b/include/mm/mmmatrix.hpp @@ -77,6 +77,9 @@ public: template friend class mm::iter::basic_iterator; + template + friend class mm::iter::diag_iterator; + //template //friend class mm::iter::basic_iterator>; @@ -183,11 +186,6 @@ void mm::basic_matrix::swap_cols(std::size_t x, std::size_t y) { template class mm::matrix { -protected: - - // shallow construction - matrix(std::shared_ptr> grid = nullptr, bool tr = false) : M(grid), transposed(tr) {} - public: //template @@ -228,7 +226,7 @@ public: * Transposition */ - matrix& transpose() + matrix& transpose_d() { transposed = !transposed; return *this; @@ -242,7 +240,7 @@ public: return m; } - inline matrix& t() + inline matrix& td() { return transpose(); } @@ -315,12 +313,12 @@ public: mm::matrix::vec_iterator operator[](std::size_t index) { - return mm::matrix::vec_iterator(*M, index, !transposed); + return mm::matrix::vec_iterator(*M, index, 0, !transposed); } mm::matrix::const_vec_iterator operator[](std::size_t index) const { - return mm::matrix::const_vec_iterator(*M, index, !transposed); + return mm::matrix::const_vec_iterator(*M, index, 0, !transposed); } /* @@ -358,7 +356,10 @@ protected: std::shared_ptr> M; - matrix shallow_cpy() + // shallow construction + matrix(std::shared_ptr> grid, bool tr = false) : M(grid), transposed(tr) {} + + matrix shallow_cpy() const { return matrix(M, transposed); } @@ -408,6 +409,7 @@ mm::matrix operator*( const mm::matrix& a, const mm::matrix& b ) { + // TODO, adjust asserts for transposed cases static_assert(P1 == P2, "invalid matrix multiplication"); assert(a.cols() == b.rows()); @@ -421,34 +423,6 @@ mm::matrix operator*( return result; } -// transposed multiplication -/*template -mm::matrix operator*( - const mm::matrix& a, - const mm::matrix& b -) { - static_assert(P1 == P2, "invalid matrix multiplication"); - assert(a.cols() == b.rows()); - - mm::matrix result; - mm::matrix bt = b.t(); // weak transposition - - for (unsigned row = 0; row < M; row++) - for (unsigned col = 0; col < N; col++) - result.at(row, col) = a[row] * bt[col]; // scalar product - - return result; -}*/ - - -/*template -void mm::square_matrix::transpose() { - for (unsigned row = 0; row < N; row++) - for (unsigned col = 0; col < row; col++) - std::swap(this->at(row, col), this->at(col, row)); -}*/ - - /* * Matrix operator << */ @@ -466,3 +440,64 @@ std::ostream& operator<<(std::ostream& os, const mm::matrix& m) { return os; } + +/* + * Square matrix + */ + +template +class mm::square_matrix : public mm::matrix +{ +public: + + using mm::matrix::matrix; + + using diag_iterator = mm::iter::diag_iterator>; + + using const_diag_iterator = mm::iter::diag_iterator::type, N, typename std::add_const>::type>; + + T trace(); + inline T tr() { return trace(); } + + mm::square_matrix::diag_iterator diag_beg(int row = 0) + { + return diag_iterator(*(this->M), row, 0); + } + + mm::square_matrix::const_diag_iterator diag_end(int row = 0) const + { + return const_diag_iterator(*(this->M), row, N); + } + + // TODO, determinant + + /// in place inverse + // TODO, det != 0 + // TODO, use gauss jordan for invertible ones + //void invert();, TODO, section algorithm + + /* + * Generate the identity + */ + + static inline constexpr mm::square_matrix identity() { + mm::square_matrix i; + for (unsigned row = 0; row < N; row++) + for (unsigned col = 0; col < N; col++) + i.at(row, col) = (row == col) ? 1 : 0; + + return i; + } +}; + +template +T mm::square_matrix::trace() +{ + T sum = 0; + for (const auto& x : diag_beg()) + sum += x; + + return sum; +} + + diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp index 291feab..a835ee0 100644 --- a/test/matrix_example.cpp +++ b/test/matrix_example.cpp @@ -8,10 +8,12 @@ int main(int argc, char *argv[]) { mm::matrix a {{1, 2}, {3, 4}, {5, 6}}; mm::matrix b {{4, 3}, {9, 1}, {2, 5}}; mm::matrix c {{1, 2, 3, 4}, {5, 6, 7, 8}}; + auto ct = c.t(); std::cout << "a = \n" << a; std::cout << "b = \n" << b; std::cout << "c = \n" << c; + std::cout << "c^t = \n" << ct; std::cout << std::endl; // access elements @@ -20,7 +22,6 @@ int main(int argc, char *argv[]) { std::cout << "a[2][0] = " << a[2][0] << std::endl;; std::cout << std::endl; - /* // basic operations std::cout << "Basic operations" << std::endl; std::cout << "a + b = \n" << a + b; @@ -28,18 +29,18 @@ int main(int argc, char *argv[]) { std::cout << "a * c = \n" << a * c; std::cout << "a * 2 = \n" << a * 2; std::cout << "2 * a = \n" << 2 * a; - std::cout << "a.td() = \n" << a.td(); // or a.trasposed(); - std::cout << std::endl;*/ + std::cout << "a.td() = \n" << a.t(); // or a.trasposed(); + std::cout << std::endl; // special matrices - /*mm::square_matrix, 2> f {{{2, 3}, {1, 4}}, {{6, 1}, {-3, 4}}}; + mm::square_matrix, 2> f {{{2, 3}, {1, 4}}, {{6, 1}, {-3, 4}}}; std::cout << "Square matrix" << std::endl; std::cout << "f = \n" << f; std::cout << "tr(f) = " << f.tr(); //or f.trace() << std::endl; - mm::t_square_matrix, 2>& ft = f.t(); + auto ft = f.t(); std::cout << "after in place transpose f.t(), f = \n" << ft; std::cout << std::endl; @@ -47,8 +48,7 @@ int main(int argc, char *argv[]) { std::cout << "Identity matrix" << std::endl; std::cout << "I = \n" << identity; - std::cout << std::endl; */ - + std::cout << std::endl; return 0; } -- cgit v1.2.1