From a9828ad7c2b73527bcb4b28e49193e4523f21ad5 Mon Sep 17 00:00:00 2001
From: Nao Pross <naopross@thearcway.org>
Date: Sat, 23 Feb 2019 16:42:19 +0100
Subject: Change storage for matrix to std::array, update matrix_example

---
 include/mmmatrix.hpp    | 60 ++++++++++++++++++++++++-------------------------
 test/matrix_example.cpp | 16 +++++++++++--
 2 files changed, 43 insertions(+), 33 deletions(-)

diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp
index b973aba..8f60312 100644
--- a/include/mmmatrix.hpp
+++ b/include/mmmatrix.hpp
@@ -15,6 +15,7 @@
 #include <cstring>
 #include <cassert>
 #include <initializer_list>
+#include <array>
 
 namespace mm {
     template<typename T, std::size_t Rows, std::size_t Cols>
@@ -44,6 +45,8 @@ public:
     static constexpr std::size_t rows = Rows;
     static constexpr std::size_t cols = Cols;
 
+    basic_matrix();
+
     // from initializer_list
     basic_matrix(std::initializer_list<std::initializer_list<T>> l);
 
@@ -98,42 +101,35 @@ public:
     }
 
 private:
-    T data[Rows * Cols] = {};
+    std::array<T, Rows * Cols> data;
 };
 
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix() {
+    std::fill(data.begin(), data.end(), 0);
+}
 
 template<typename T, std::size_t Rows, std::size_t Cols>
 mm::basic_matrix<T, Rows, Cols>::basic_matrix(
     std::initializer_list<std::initializer_list<T>> l
 ) {
     assert(l.size() == Rows);
+    auto data_it = data.begin();
 
-    auto row_it = l.begin();
-    for (unsigned row = 0; (row < Rows) && (row_it != l.end()); row++) {
-        assert((*row_it).size() == Cols);
-
-        auto col_it = (*row_it).begin();
-        for (unsigned col = 0; (col < Cols) && (col_it != (*row_it).end()); col++) {
-            this->at(row, col) = *col_it;
-            ++col_it;
-        }
-        ++row_it;
+    for (auto&& row : l) {
+        data_it = std::copy(row.begin(), row.end(), data_it);
     }
 }
 
 template<typename T, std::size_t Rows, std::size_t Cols>
 mm::basic_matrix<T, Rows, Cols>::basic_matrix(
     const mm::basic_matrix<T, Rows, Cols>& other
-) {
-    std::memcpy(&data, &other.data, sizeof(data));
-}
+) : data(other.data) {}
 
 template<typename T, std::size_t Rows, std::size_t Cols>
 mm::basic_matrix<T, Rows, Cols>::basic_matrix(
     mm::basic_matrix<T, Rows, Cols>&& other
-) {
-    data = other.data;
-}
+) : data(std::forward<decltype(other.data)>(other.data)) {}
 
 template<typename T, std::size_t Rows, std::size_t Cols>
 template<std::size_t ORows, std::size_t OCols>
@@ -148,18 +144,19 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix(
         "cannot copy a larger matrix into a smaller one"
     );
 
-    std::memcpy(&data, &other.data, sizeof(data));
+    std::fill(data.begin(), data.end(), 0);
+    for (unsigned row = 0; row < Rows; row++)
+        for (unsigned col = 0; col < Cols; col++)
+            this->at(row, col) = other.at(row, col);
 }
 
 template<typename T, std::size_t Rows, std::size_t Cols>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(const T (& values)[Rows][Cols]) {
-    std::memcpy(&data, &values, sizeof(data));
-}
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(const T (& values)[Rows][Cols])
+    : data(values) {}
 
 template<typename T, std::size_t Rows, std::size_t Cols>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(T (&& values)[Rows][Cols]) {
-    data = values;
-}
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(T (&& values)[Rows][Cols])
+    : data(std::forward<decltype(values)>(values)) {}
 
 
 /* member functions */
@@ -218,7 +215,7 @@ mm::basic_matrix<T, N, M> mm::basic_matrix<T, M, N>::transposed() const {
 
     for (unsigned row = 0; row < M; row++)
         for (unsigned col = 0; col < N; col++)
-            result.at(row, col) = this->at(row, col);
+            result.at(col, row) = this->at(row, col);
 
     return result;
 }
@@ -234,7 +231,7 @@ mm::basic_matrix<T, Rows, Cols> operator+(
 
     for (unsigned row = 0; row < Rows; row++)
         for (unsigned col = 0; col < Cols; col++)
-            result.at(row, col) = a.at(row, col) + a.at(row, col);
+            result.at(row, col) = a.at(row, col) + b.at(row, col);
     
     return result;
 }
@@ -260,17 +257,18 @@ mm::basic_matrix<T, Rows, Cols> operator*(
     return m * scalar;
 }
 
-template<typename T, std::size_t M, std::size_t P, std::size_t N>
+template<typename T, std::size_t M, std::size_t P1, std::size_t P2, std::size_t N>
 mm::basic_matrix<T, M, N> operator*(
-    const mm::basic_matrix<T, M, P>& a,
-    const mm::basic_matrix<T, P, N>& b
+    const mm::basic_matrix<T, M, P1>& a,
+    const mm::basic_matrix<T, P2, N>& b
 ) {
+    static_assert(P1 == P2, "invalid matrix multiplication");
     mm::basic_matrix<T, M, N>  result;
 
     // TODO: use a more efficient algorithm
     for (unsigned row = 0; row < M; row++)
         for (unsigned col = 0; col < N; col++)
-            for (int k = 0; k < P; k++)
+            for (unsigned k = 0; k < P1; k++)
                 result.at(row, col) = a.at(row, k) * b.at(k, col);
 
     return result;
@@ -282,7 +280,7 @@ mm::basic_matrix<T, Rows, Cols> operator-(
     const mm::basic_matrix<T, Rows, Cols>& a,
     const mm::basic_matrix<T, Rows, Cols>& b
 ) {
-    return a + static_cast<T>(-1) * b;
+    return a + (static_cast<T>(-1) * b);
 }
 
 template<typename T, std::size_t Rows, std::size_t Cols>
diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp
index 3fb7c78..4cf1863 100644
--- a/test/matrix_example.cpp
+++ b/test/matrix_example.cpp
@@ -5,9 +5,21 @@
 
 int main(int argc, char *argv[]) {
 	std::cout << "MxN dimensional (int) matrices" << std::endl;
-    mm::matrix<int, 3, 2> m {{1, 2}, {3, 4}, {5, 6}};
+    mm::matrix<int, 3, 2> a {{1, 2}, {3, 4}, {5, 6}};
+    mm::matrix<int, 3, 2> b {{4, 3}, {9, 1}, {2, 5}};
+    mm::matrix<int, 2, 4> c {{1, 2, 3, 4}, {5, 6, 7, 8}};
 
-    std::cout << m;
+    std::cout << "a = \n" << a;
+    std::cout << "b = \n" << b;
+    std::cout << "c = \n" << c;
+    
+    // basic operations
+    std::cout << "a + b = \n" << a + b;
+    std::cout << "a - b = \n" << a - b;
+    std::cout << "a * c = \n" << a * c;
+    std::cout << "a * 2 = \n" << a * 2;
+    std::cout << "2 * a = \n" << 2 * a;
+    std::cout << "tr(a) = \n" << a.trd();
 
 	return 0;
 }
-- 
cgit v1.2.1