From b22ab9a435dceb362cb9eab8ec195d908fd8d5e9 Mon Sep 17 00:00:00 2001
From: Nao Pross <naopross@thearcway.org>
Date: Sat, 2 Mar 2019 11:22:17 +0100
Subject: Update matrix test, add square matrix trace and fix comments

---
 include/mmmatrix.hpp | 97 ++++++++++++++++++++++++++++++++++------------------
 1 file changed, 64 insertions(+), 33 deletions(-)

(limited to 'include')

diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp
index 6a31610..5285429 100644
--- a/include/mmmatrix.hpp
+++ b/include/mmmatrix.hpp
@@ -22,20 +22,24 @@ namespace mm {
     template<typename T, std::size_t Rows, std::size_t Cols>
     class basic_matrix;
 
+    /* specialization of basic_matrx for Cols = 1 */
+    template<typename T, std::size_t Rows>
+    class row_vec;
+
+    /* specialization of basic_matrx for Rows = 1 */
+    template<typename T, std::size_t Cols>
+    class col_vec;
+
+    /* shorter name for basic_matrix */
     template<typename T, std::size_t Rows, std::size_t Cols>
     class matrix;
 
+    /* specialization of basic_matrix for Rows == Cols */
     template<typename T, std::size_t N>
     class square_matrix;
 
     // template<typename T, std::size_t N>
     // class diag_matrix;
-
-    template<typename T, std::size_t Rows>
-    class row_vec;
-
-    template<typename T, std::size_t Cols>
-    class col_vec;
 }
 
 template<typename T, std::size_t Rows, std::size_t Cols>
@@ -73,7 +77,7 @@ public:
 
     // mathematical operations
     virtual basic_matrix<T, Cols, Rows> transposed() const;
-    inline basic_matrix<T, Cols, Rows> trd() const { return transposed(); }
+    inline basic_matrix<T, Cols, Rows> td() const { return transposed(); }
 
     // bool is_invertible() const;
     // basic_matrix<T, Rows, Cols> inverse() const;
@@ -102,8 +106,8 @@ public:
     }
 
 protected:
-    template<typename Iterator>
-    basic_matrix(Iterator begin, Iterator end);
+    template<typename ConstIterator>
+    basic_matrix(ConstIterator begin, ConstIterator end);
 
 private:
     std::array<T, Rows * Cols> data;
@@ -157,8 +161,10 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix(
 
 /* protected construtor */
 template<typename T, std::size_t Rows, std::size_t Cols>
-template<typename Iterator>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(Iterator begin, Iterator end) {
+template<typename ConstIterator>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(
+    ConstIterator begin, ConstIterator end
+) {
     assert(static_cast<unsigned>(std::distance(begin, end)) >= ((Rows * Cols)));
     std::copy(begin, end, data.begin());
 }
@@ -188,8 +194,8 @@ auto mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) {
         return data.at(index);
     } else {
         return row_vec<T, Rows>(
-            data.begin() + (index * Cols),
-            data.begin() + ((index + 1) * Cols) + 1
+            data.cbegin() + (index * Cols),
+            data.cbegin() + ((index + 1) * Cols) + 1
         );
     }
 }
@@ -300,26 +306,6 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols>
 }
 
 
-/* square matrix specializaiton */
-
-template<typename T, std::size_t N>
-class mm::square_matrix : public mm::basic_matrix<T, N, N> {
-public:
-    /// in place transpose
-    void transpose();  
-    inline void tr() { transpose(); }
-
-    /// in place inverse
-    void invert();
-};
-
-
-template<typename T, std::size_t N>
-void mm::square_matrix<T, N>::transpose() {
-    for (unsigned row = 0; row < N; row++)
-        for (unsigned col = 0; col < row; col++)
-            std::swap(this->at(row, col), this->at(col, row));
-}
 
 
 /* row vector specialization */
@@ -336,8 +322,53 @@ public:
     using mm::basic_matrix<T, 1, Cols>::basic_matrix;
 };
 
+/* general specialization (alias) */
 template<typename T, std::size_t Rows, std::size_t Cols>
 class mm::matrix : public mm::basic_matrix<T, Rows, Cols> {
 public:
     using mm::basic_matrix<T, Rows, Cols>::basic_matrix;
 };
+
+/* square matrix specializaiton */
+template<typename T, std::size_t N>
+class mm::square_matrix : public mm::basic_matrix<T, N, N> {
+public:
+    using mm::basic_matrix<T, N, N>::basic_matrix;
+
+    /// in place transpose
+    void transpose();  
+    inline void t() { transpose(); }
+
+    T trace();
+    inline T tr() { return trace(); }
+
+    /// in place inverse
+    void invert();
+
+    // get the identity of size N
+    static inline constexpr square_matrix<T, N> identity() {
+        square_matrix<T, N> i;
+        for (unsigned row = 0; row < N; row++)
+            for (unsigned col = 0; col < N; col++)
+                i.at(row, col) = (row == col) ? 1 : 0;
+
+        return i;
+    }
+};
+
+
+template<typename T, std::size_t N>
+void mm::square_matrix<T, N>::transpose() {
+    for (unsigned row = 0; row < N; row++)
+        for (unsigned col = 0; col < row; col++)
+            std::swap(this->at(row, col), this->at(col, row));
+}
+
+template<typename T, std::size_t N>
+T mm::square_matrix<T, N>::trace() {
+    T sum = 0;
+    for (unsigned i = 0; i < N; i++)
+        sum += this->at(i, i);
+
+    return sum;
+}
-- 
cgit v1.2.1