summaryrefslogtreecommitdiffstats
path: root/include
diff options
context:
space:
mode:
Diffstat (limited to 'include')
-rw-r--r--include/mm/debug.hpp78
-rw-r--r--include/mm/mmiterator.hpp39
-rw-r--r--include/mm/mmmatrix.hpp111
3 files changed, 184 insertions, 44 deletions
diff --git a/include/mm/debug.hpp b/include/mm/debug.hpp
new file mode 100644
index 0000000..0254744
--- /dev/null
+++ b/include/mm/debug.hpp
@@ -0,0 +1,78 @@
+#pragma once
+
+#ifndef __NPDEBUG__
+#define __NPDEBUG__
+
+#include <iostream>
+#include <sstream>
+
+#ifndef NDEBUG
+ #define __FILENAME__ (\
+ __builtin_strrchr(__FILE__, '/') ? \
+ __builtin_strrchr(__FILE__, '/') + 1 : __FILE__)
+
+ #define npdebug_prep(); { \
+ std::cerr << "[" << __FILENAME__ \
+ << ":" << __LINE__ \
+ << ", " << __func__ \
+ << "] " ; \
+ }
+
+ #define npdebug(...); { \
+ npdebug_prep(); \
+ np::va_debug(__VA_ARGS__); \
+ }
+
+ namespace np {
+ template<typename... Args>
+ inline void va_debug(Args&&... args) {
+ (std::cerr << ... << args) << std::endl;
+ }
+
+ template<typename T>
+ void range_debug(const T& t) {
+ range_debug("", t);
+ }
+
+ template<typename T>
+ void range_debug(const std::string& msg, const T& t) {
+ std::string out;
+ for (auto elem : t)
+ out += elem += ", ";
+
+ npdebug(msg, out);
+ }
+
+ template<typename T>
+ T inspect(const T& t) {
+ npdebug(t);
+ return t;
+ }
+
+ template<typename T>
+ T inspect(const std::string& msg, const T& t) {
+ npdebug(msg, t);
+ return t;
+ }
+ }
+#else
+ #define npdebug(...) {}
+
+ namespace np {
+ template<typename... Args>
+ inline void va_debug(Args&... args) {}
+
+ template<typename T>
+ inline void range_debug(const T& t) {}
+
+ template<typename T>
+ inline void range_debug(const std::string& msg, const T& t) {}
+
+ template<typename T>
+ T inspect(const T& t) { return t; }
+
+ template<typename T>
+ T inspect(const std::string& msg, const T& t) { return t; }
+ }
+#endif // NDEBUG
+#endif // __NPDEBUG__
diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp
index c67b92b..7b3480e 100644
--- a/include/mm/mmiterator.hpp
+++ b/include/mm/mmiterator.hpp
@@ -1,5 +1,7 @@
#pragma once
+#include "debug.hpp"
+
namespace mm::iter {
template<typename T, std::size_t Rows, std::size_t Cols, class IterType, class Grid>
@@ -33,14 +35,14 @@ public:
IterType operator++()
{
- IterType it = *this;
+ IterType it = cpy();
++index;
return it;
}
IterType operator--()
{
- IterType it = *this;
+ IterType it = cpy();
--index;
return it;
}
@@ -48,13 +50,13 @@ public:
IterType& operator++(int)
{
++index;
- return *this;
+ return ref();
}
IterType& operator--(int)
{
--index;
- return *this;
+ return ref();
}
bool operator==(const IterType& other) const
@@ -111,6 +113,9 @@ protected:
const std::size_t position; // fixed index, negative too for diagonal iterator
std::size_t index; // variable index
+
+ virtual IterType& ref() = 0;
+ virtual IterType cpy() = 0;
};
template<typename T, std::size_t Rows, std::size_t Cols, class Grid>
@@ -118,12 +123,24 @@ class mm::iter::basic_iterator : public mm::iter::vector_iterator<T, Rows, Cols,
{
bool direction;
+ virtual mm::iter::basic_iterator<T, Rows, Cols, Grid>& ref() override
+ {
+ return *this;
+ }
+
+ virtual mm::iter::basic_iterator<T, Rows, Cols, Grid> cpy() override
+ {
+ return *this;
+ }
+
public:
basic_iterator(Grid& A, std::size_t pos, std::size_t _index = 0, bool dir = true)
: mm::iter::vector_iterator<T, Rows, Cols, mm::iter::basic_iterator<T, Rows, Cols, Grid>, Grid>
(A, pos, _index), direction(dir)
{
+ //npdebug("Position: ", pos, ", Rows: ", Rows, " Cols: ", Cols, ", Direction: ", dir)
+
if (direction)
assert(pos < Rows);
else
@@ -162,11 +179,21 @@ class mm::iter::diag_iterator : public mm::iter::vector_iterator<T, N, N, mm::it
{
bool sign;
+ virtual mm::iter::diag_iterator<T, N, Grid>& ref() override
+ {
+ return *this;
+ }
+
+ virtual mm::iter::diag_iterator<T, N, Grid> cpy() override
+ {
+ return *this;
+ }
+
public:
- diag_iterator(Grid& A, signed long pos, std::size_t _index = 0)
+ diag_iterator(Grid& A, signed long int pos, std::size_t _index = 0)
: mm::iter::vector_iterator<T, N, N, mm::iter::diag_iterator<T, N, Grid>, Grid>
- (A, static_cast<std::size_t>(abs(pos)), _index), sign(pos >= 0)
+ (A, static_cast<std::size_t>(labs(pos)), _index), sign(pos >= 0)
{
assert(this->position < N);
}
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp
index f88e90c..5eb0171 100644
--- a/include/mm/mmmatrix.hpp
+++ b/include/mm/mmmatrix.hpp
@@ -77,6 +77,9 @@ public:
template<typename U, std::size_t ORows, std::size_t OCols, class Grid>
friend class mm::iter::basic_iterator;
+ template<typename U, std::size_t ON, class Grid>
+ friend class mm::iter::diag_iterator;
+
//template<typename U, std::size_t ORows, std::size_t OCols, class Grid>
//friend class mm::iter::basic_iterator<T, Rows, Cols, mm::basic_matrix<T, Rows, Cols>>;
@@ -183,11 +186,6 @@ void mm::basic_matrix<T, Rows, Cols>::swap_cols(std::size_t x, std::size_t y) {
template<typename T, std::size_t Rows, std::size_t Cols>
class mm::matrix
{
-protected:
-
- // shallow construction
- matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid = nullptr, bool tr = false) : M(grid), transposed(tr) {}
-
public:
//template<typename U, std::size_t ORows, std::size_t OCols>
@@ -228,7 +226,7 @@ public:
* Transposition
*/
- matrix<T, Rows, Cols>& transpose()
+ matrix<T, Rows, Cols>& transpose_d()
{
transposed = !transposed;
return *this;
@@ -242,7 +240,7 @@ public:
return m;
}
- inline matrix<T, Rows, Cols>& t()
+ inline matrix<T, Rows, Cols>& td()
{
return transpose();
}
@@ -315,12 +313,12 @@ public:
mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index)
{
- return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, !transposed);
+ return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, 0, !transposed);
}
mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const
{
- return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, !transposed);
+ return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, 0, !transposed);
}
/*
@@ -358,7 +356,10 @@ protected:
std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> M;
- matrix<T, Rows, Cols> shallow_cpy()
+ // shallow construction
+ matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid, bool tr = false) : M(grid), transposed(tr) {}
+
+ matrix<T, Rows, Cols> shallow_cpy() const
{
return matrix<T, Rows, Cols>(M, transposed);
}
@@ -408,6 +409,7 @@ mm::matrix<T, M, N> operator*(
const mm::matrix<T, M, P1>& a,
const mm::matrix<T, P2, N>& b
) {
+ // TODO, adjust asserts for transposed cases
static_assert(P1 == P2, "invalid matrix multiplication");
assert(a.cols() == b.rows());
@@ -421,34 +423,6 @@ mm::matrix<T, M, N> operator*(
return result;
}
-// transposed multiplication
-/*template<typename T, std::size_t M, std::size_t P1, std::size_t P2, std::size_t N>
-mm::matrix<T, M, N> operator*(
- const mm::matrix<T, P1, M>& a,
- const mm::matrix<T, P2, N>& b
-) {
- static_assert(P1 == P2, "invalid matrix multiplication");
- assert(a.cols() == b.rows());
-
- mm::matrix<T, M, N> result;
- mm::matrix<T, P2, N> bt = b.t(); // weak transposition
-
- for (unsigned row = 0; row < M; row++)
- for (unsigned col = 0; col < N; col++)
- result.at(row, col) = a[row] * bt[col]; // scalar product
-
- return result;
-}*/
-
-
-/*template<typename T, std::size_t N>
-void mm::square_matrix<T, N>::transpose() {
- for (unsigned row = 0; row < N; row++)
- for (unsigned col = 0; col < row; col++)
- std::swap(this->at(row, col), this->at(col, row));
-}*/
-
-
/*
* Matrix operator <<
*/
@@ -466,3 +440,64 @@ std::ostream& operator<<(std::ostream& os, const mm::matrix<T, Rows, Cols>& m) {
return os;
}
+
+/*
+ * Square matrix
+ */
+
+template<typename T, std::size_t N>
+class mm::square_matrix : public mm::matrix<T, N, N>
+{
+public:
+
+ using mm::matrix<T, N, N>::matrix;
+
+ using diag_iterator = mm::iter::diag_iterator<T, N, mm::basic_matrix<T, N, N>>;
+
+ using const_diag_iterator = mm::iter::diag_iterator<typename std::add_const<T>::type, N, typename std::add_const<mm::basic_matrix<T, N, N>>::type>;
+
+ T trace();
+ inline T tr() { return trace(); }
+
+ mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0)
+ {
+ return diag_iterator(*(this->M), row, 0);
+ }
+
+ mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const
+ {
+ return const_diag_iterator(*(this->M), row, N);
+ }
+
+ // TODO, determinant
+
+ /// in place inverse
+ // TODO, det != 0
+ // TODO, use gauss jordan for invertible ones
+ //void invert();, TODO, section algorithm
+
+ /*
+ * Generate the identity
+ */
+
+ static inline constexpr mm::square_matrix<T, N> identity() {
+ mm::square_matrix<T, N> i;
+ for (unsigned row = 0; row < N; row++)
+ for (unsigned col = 0; col < N; col++)
+ i.at(row, col) = (row == col) ? 1 : 0;
+
+ return i;
+ }
+};
+
+template<typename T, std::size_t N>
+T mm::square_matrix<T, N>::trace()
+{
+ T sum = 0;
+ for (const auto& x : diag_beg())
+ sum += x;
+
+ return sum;
+}
+
+