summaryrefslogtreecommitdiffstats
path: root/include/mm/mmmatrix.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/mm/mmmatrix.hpp')
-rw-r--r--include/mm/mmmatrix.hpp111
1 files changed, 73 insertions, 38 deletions
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp
index f88e90c..5eb0171 100644
--- a/include/mm/mmmatrix.hpp
+++ b/include/mm/mmmatrix.hpp
@@ -77,6 +77,9 @@ public:
template<typename U, std::size_t ORows, std::size_t OCols, class Grid>
friend class mm::iter::basic_iterator;
+ template<typename U, std::size_t ON, class Grid>
+ friend class mm::iter::diag_iterator;
+
//template<typename U, std::size_t ORows, std::size_t OCols, class Grid>
//friend class mm::iter::basic_iterator<T, Rows, Cols, mm::basic_matrix<T, Rows, Cols>>;
@@ -183,11 +186,6 @@ void mm::basic_matrix<T, Rows, Cols>::swap_cols(std::size_t x, std::size_t y) {
template<typename T, std::size_t Rows, std::size_t Cols>
class mm::matrix
{
-protected:
-
- // shallow construction
- matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid = nullptr, bool tr = false) : M(grid), transposed(tr) {}
-
public:
//template<typename U, std::size_t ORows, std::size_t OCols>
@@ -228,7 +226,7 @@ public:
* Transposition
*/
- matrix<T, Rows, Cols>& transpose()
+ matrix<T, Rows, Cols>& transpose_d()
{
transposed = !transposed;
return *this;
@@ -242,7 +240,7 @@ public:
return m;
}
- inline matrix<T, Rows, Cols>& t()
+ inline matrix<T, Rows, Cols>& td()
{
return transpose();
}
@@ -315,12 +313,12 @@ public:
mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index)
{
- return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, !transposed);
+ return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, 0, !transposed);
}
mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const
{
- return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, !transposed);
+ return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, 0, !transposed);
}
/*
@@ -358,7 +356,10 @@ protected:
std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> M;
- matrix<T, Rows, Cols> shallow_cpy()
+ // shallow construction
+ matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid, bool tr = false) : M(grid), transposed(tr) {}
+
+ matrix<T, Rows, Cols> shallow_cpy() const
{
return matrix<T, Rows, Cols>(M, transposed);
}
@@ -408,6 +409,7 @@ mm::matrix<T, M, N> operator*(
const mm::matrix<T, M, P1>& a,
const mm::matrix<T, P2, N>& b
) {
+ // TODO, adjust asserts for transposed cases
static_assert(P1 == P2, "invalid matrix multiplication");
assert(a.cols() == b.rows());
@@ -421,34 +423,6 @@ mm::matrix<T, M, N> operator*(
return result;
}
-// transposed multiplication
-/*template<typename T, std::size_t M, std::size_t P1, std::size_t P2, std::size_t N>
-mm::matrix<T, M, N> operator*(
- const mm::matrix<T, P1, M>& a,
- const mm::matrix<T, P2, N>& b
-) {
- static_assert(P1 == P2, "invalid matrix multiplication");
- assert(a.cols() == b.rows());
-
- mm::matrix<T, M, N> result;
- mm::matrix<T, P2, N> bt = b.t(); // weak transposition
-
- for (unsigned row = 0; row < M; row++)
- for (unsigned col = 0; col < N; col++)
- result.at(row, col) = a[row] * bt[col]; // scalar product
-
- return result;
-}*/
-
-
-/*template<typename T, std::size_t N>
-void mm::square_matrix<T, N>::transpose() {
- for (unsigned row = 0; row < N; row++)
- for (unsigned col = 0; col < row; col++)
- std::swap(this->at(row, col), this->at(col, row));
-}*/
-
-
/*
* Matrix operator <<
*/
@@ -466,3 +440,64 @@ std::ostream& operator<<(std::ostream& os, const mm::matrix<T, Rows, Cols>& m) {
return os;
}
+
+/*
+ * Square matrix
+ */
+
+template<typename T, std::size_t N>
+class mm::square_matrix : public mm::matrix<T, N, N>
+{
+public:
+
+ using mm::matrix<T, N, N>::matrix;
+
+ using diag_iterator = mm::iter::diag_iterator<T, N, mm::basic_matrix<T, N, N>>;
+
+ using const_diag_iterator = mm::iter::diag_iterator<typename std::add_const<T>::type, N, typename std::add_const<mm::basic_matrix<T, N, N>>::type>;
+
+ T trace();
+ inline T tr() { return trace(); }
+
+ mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0)
+ {
+ return diag_iterator(*(this->M), row, 0);
+ }
+
+ mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const
+ {
+ return const_diag_iterator(*(this->M), row, N);
+ }
+
+ // TODO, determinant
+
+ /// in place inverse
+ // TODO, det != 0
+ // TODO, use gauss jordan for invertible ones
+ //void invert();, TODO, section algorithm
+
+ /*
+ * Generate the identity
+ */
+
+ static inline constexpr mm::square_matrix<T, N> identity() {
+ mm::square_matrix<T, N> i;
+ for (unsigned row = 0; row < N; row++)
+ for (unsigned col = 0; col < N; col++)
+ i.at(row, col) = (row == col) ? 1 : 0;
+
+ return i;
+ }
+};
+
+template<typename T, std::size_t N>
+T mm::square_matrix<T, N>::trace()
+{
+ T sum = 0;
+ for (const auto& x : diag_beg())
+ sum += x;
+
+ return sum;
+}
+
+