From 2dd7c3bc4a6a49539e9847ec56a69cbf023e7e9b Mon Sep 17 00:00:00 2001
From: Nao Pross <naopross@thearcway.org>
Date: Fri, 1 Mar 2019 17:00:55 +0100
Subject: Fix matrix operator[] to allow M[j][k] and operator<< formatting

---
 include/mmmatrix.hpp    | 59 ++++++++++++++++++++++++++-----------------------
 test/matrix_example.cpp |  4 ++++
 2 files changed, 35 insertions(+), 28 deletions(-)

diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp
index 8f60312..6a31610 100644
--- a/include/mmmatrix.hpp
+++ b/include/mmmatrix.hpp
@@ -12,6 +12,7 @@
 #pragma once
 
 #include <iostream>
+#include <iomanip>
 #include <cstring>
 #include <cassert>
 #include <initializer_list>
@@ -42,6 +43,9 @@ class mm::basic_matrix {
 public:
     using type = T;
 
+    template<typename U, std::size_t ORows, std::size_t OCols>
+    friend class mm::basic_matrix;
+
     static constexpr std::size_t rows = Rows;
     static constexpr std::size_t cols = Cols;
 
@@ -58,20 +62,17 @@ public:
     template<std::size_t ORows, std::size_t OCols>
     basic_matrix(const basic_matrix<T, ORows, OCols>& other);
 
-    // copy or move from 2D array
-    basic_matrix(const T (& values)[Rows][Cols]);
-    basic_matrix(T (&& values)[Rows][Cols]);
-
     // access data
     T& at(std::size_t row, std::size_t col);
     const T& at(std::size_t row, std::size_t col) const;
-    auto&& operator[](std::size_t index);
+    // allows to access a matrix M at row j col k with M[j][k]
+    auto operator[](std::size_t index);
 
     void swap_rows(std::size_t x, std::size_t y);
     void swap_cols(std::size_t x, std::size_t y);
 
     // mathematical operations
-    basic_matrix<T, Cols, Rows> transposed() const;
+    virtual basic_matrix<T, Cols, Rows> transposed() const;
     inline basic_matrix<T, Cols, Rows> trd() const { return transposed(); }
 
     // bool is_invertible() const;
@@ -79,7 +80,7 @@ public:
 
 
     /// downcast to square matrix
-    inline constexpr bool is_square() const { return (Rows == Cols); }
+    static inline constexpr bool is_square() { return (Rows == Cols); }
     inline constexpr square_matrix<T, Rows> to_square() const {
         static_assert(is_square());
         return static_cast<square_matrix<T, Rows>>(*this);
@@ -87,19 +88,23 @@ public:
 
 
     /// downcast to row_vector
-    inline constexpr bool is_row_vec() const { return (Cols == 1); }
+    static inline constexpr bool is_row_vec() { return (Cols == 1); }
     inline constexpr row_vec<T, Rows> to_row_vec() const {
         static_assert(is_row_vec());
         return static_cast<row_vec<T, Rows>>(*this);
     }
 
     /// downcast to col_vector
-    inline constexpr bool is_col_vec() const { return (Rows == 1); }
+    static inline constexpr bool is_col_vec() { return (Rows == 1); }
     inline constexpr col_vec<T, Cols> to_col_vec() const {
         static_assert(is_col_vec());
         return static_cast<col_vec<T, Cols>>(*this);
     }
 
+protected:
+    template<typename Iterator>
+    basic_matrix(Iterator begin, Iterator end);
+
 private:
     std::array<T, Rows * Cols> data;
 };
@@ -150,13 +155,13 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix(
             this->at(row, col) = other.at(row, col);
 }
 
+/* protected construtor */
 template<typename T, std::size_t Rows, std::size_t Cols>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(const T (& values)[Rows][Cols])
-    : data(values) {}
-
-template<typename T, std::size_t Rows, std::size_t Cols>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(T (&& values)[Rows][Cols])
-    : data(std::forward<decltype(values)>(values)) {}
+template<typename Iterator>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(Iterator begin, Iterator end) {
+    assert(static_cast<unsigned>(std::distance(begin, end)) >= ((Rows * Cols)));
+    std::copy(begin, end, data.begin());
+}
 
 
 /* member functions */
@@ -178,17 +183,15 @@ const T& mm::basic_matrix<T, Rows, Cols>::at(std::size_t row, std::size_t col) c
 }
 
 template<typename T, std::size_t Rows, std::size_t Cols>
-auto&& mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) {
-    if constexpr (is_row_vec()) {
-        static_assert(index < Rows);
-        return data[index];
-    } else if constexpr (is_col_vec()) {
-        static_assert(index < Cols);
-        return data[index];
+auto mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) {
+    if constexpr (is_row_vec() || is_col_vec()) {
+        return data.at(index);
+    } else {
+        return row_vec<T, Rows>(
+            data.begin() + (index * Cols),
+            data.begin() + ((index + 1) * Cols) + 1
+        );
     }
-
-    // TODO: fix
-    // return row_vec<T, Rows>(std::move(data[index]));
 }
 
 template<typename T, std::size_t Rows, std::size_t Cols>
@@ -283,14 +286,14 @@ mm::basic_matrix<T, Rows, Cols> operator-(
     return a + (static_cast<T>(-1) * b);
 }
 
-template<typename T, std::size_t Rows, std::size_t Cols>
+template<typename T, std::size_t Rows, std::size_t Cols, unsigned NumW = 3>
 std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols>& m) {
     for (unsigned row = 0; row < Rows; row++) {
         os << "[ ";
         for (unsigned col = 0; col < (Cols -1); col++) {
-            os << m.at(row, col) << ", ";
+            os << std::setw(NumW) << m.at(row, col) << ", ";
         }
-        os << m.at(row, (Cols -1)) << " ]\n";
+        os << std::setw(NumW) << m.at(row, (Cols -1)) << " ]\n";
     }
 
     return os;
diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp
index 4cf1863..26aeede 100644
--- a/test/matrix_example.cpp
+++ b/test/matrix_example.cpp
@@ -12,6 +12,10 @@ int main(int argc, char *argv[]) {
     std::cout << "a = \n" << a;
     std::cout << "b = \n" << b;
     std::cout << "c = \n" << c;
+
+    // access elements
+    std::cout << "a.at(2,0) = " << a.at(2, 0) << std::endl;
+    std::cout << "a[2][0]   = " << a[2][0] << std::endl;;
     
     // basic operations
     std::cout << "a + b = \n" << a + b;
-- 
cgit v1.2.1