summaryrefslogtreecommitdiffstats
path: root/include/mm
diff options
context:
space:
mode:
authorNao Pross <naopross@thearcway.org>2019-03-01 17:00:55 +0100
committerNao Pross <naopross@thearcway.org>2019-03-01 17:11:40 +0100
commit2dd7c3bc4a6a49539e9847ec56a69cbf023e7e9b (patch)
tree18c5abf82f63080c80b6f6533e91f16244872087 /include/mm
parentChange storage for matrix to std::array, update matrix_example (diff)
downloadlibmm-2dd7c3bc4a6a49539e9847ec56a69cbf023e7e9b.tar.gz
libmm-2dd7c3bc4a6a49539e9847ec56a69cbf023e7e9b.zip
Fix matrix operator[] to allow M[j][k] and operator<< formatting
Diffstat (limited to '')
-rw-r--r--include/mmmatrix.hpp59
1 files changed, 31 insertions, 28 deletions
diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp
index 8f60312..6a31610 100644
--- a/include/mmmatrix.hpp
+++ b/include/mmmatrix.hpp
@@ -12,6 +12,7 @@
#pragma once
#include <iostream>
+#include <iomanip>
#include <cstring>
#include <cassert>
#include <initializer_list>
@@ -42,6 +43,9 @@ class mm::basic_matrix {
public:
using type = T;
+ template<typename U, std::size_t ORows, std::size_t OCols>
+ friend class mm::basic_matrix;
+
static constexpr std::size_t rows = Rows;
static constexpr std::size_t cols = Cols;
@@ -58,20 +62,17 @@ public:
template<std::size_t ORows, std::size_t OCols>
basic_matrix(const basic_matrix<T, ORows, OCols>& other);
- // copy or move from 2D array
- basic_matrix(const T (& values)[Rows][Cols]);
- basic_matrix(T (&& values)[Rows][Cols]);
-
// access data
T& at(std::size_t row, std::size_t col);
const T& at(std::size_t row, std::size_t col) const;
- auto&& operator[](std::size_t index);
+ // allows to access a matrix M at row j col k with M[j][k]
+ auto operator[](std::size_t index);
void swap_rows(std::size_t x, std::size_t y);
void swap_cols(std::size_t x, std::size_t y);
// mathematical operations
- basic_matrix<T, Cols, Rows> transposed() const;
+ virtual basic_matrix<T, Cols, Rows> transposed() const;
inline basic_matrix<T, Cols, Rows> trd() const { return transposed(); }
// bool is_invertible() const;
@@ -79,7 +80,7 @@ public:
/// downcast to square matrix
- inline constexpr bool is_square() const { return (Rows == Cols); }
+ static inline constexpr bool is_square() { return (Rows == Cols); }
inline constexpr square_matrix<T, Rows> to_square() const {
static_assert(is_square());
return static_cast<square_matrix<T, Rows>>(*this);
@@ -87,19 +88,23 @@ public:
/// downcast to row_vector
- inline constexpr bool is_row_vec() const { return (Cols == 1); }
+ static inline constexpr bool is_row_vec() { return (Cols == 1); }
inline constexpr row_vec<T, Rows> to_row_vec() const {
static_assert(is_row_vec());
return static_cast<row_vec<T, Rows>>(*this);
}
/// downcast to col_vector
- inline constexpr bool is_col_vec() const { return (Rows == 1); }
+ static inline constexpr bool is_col_vec() { return (Rows == 1); }
inline constexpr col_vec<T, Cols> to_col_vec() const {
static_assert(is_col_vec());
return static_cast<col_vec<T, Cols>>(*this);
}
+protected:
+ template<typename Iterator>
+ basic_matrix(Iterator begin, Iterator end);
+
private:
std::array<T, Rows * Cols> data;
};
@@ -150,13 +155,13 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix(
this->at(row, col) = other.at(row, col);
}
+/* protected construtor */
template<typename T, std::size_t Rows, std::size_t Cols>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(const T (& values)[Rows][Cols])
- : data(values) {}
-
-template<typename T, std::size_t Rows, std::size_t Cols>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(T (&& values)[Rows][Cols])
- : data(std::forward<decltype(values)>(values)) {}
+template<typename Iterator>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(Iterator begin, Iterator end) {
+ assert(static_cast<unsigned>(std::distance(begin, end)) >= ((Rows * Cols)));
+ std::copy(begin, end, data.begin());
+}
/* member functions */
@@ -178,17 +183,15 @@ const T& mm::basic_matrix<T, Rows, Cols>::at(std::size_t row, std::size_t col) c
}
template<typename T, std::size_t Rows, std::size_t Cols>
-auto&& mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) {
- if constexpr (is_row_vec()) {
- static_assert(index < Rows);
- return data[index];
- } else if constexpr (is_col_vec()) {
- static_assert(index < Cols);
- return data[index];
+auto mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) {
+ if constexpr (is_row_vec() || is_col_vec()) {
+ return data.at(index);
+ } else {
+ return row_vec<T, Rows>(
+ data.begin() + (index * Cols),
+ data.begin() + ((index + 1) * Cols) + 1
+ );
}
-
- // TODO: fix
- // return row_vec<T, Rows>(std::move(data[index]));
}
template<typename T, std::size_t Rows, std::size_t Cols>
@@ -283,14 +286,14 @@ mm::basic_matrix<T, Rows, Cols> operator-(
return a + (static_cast<T>(-1) * b);
}
-template<typename T, std::size_t Rows, std::size_t Cols>
+template<typename T, std::size_t Rows, std::size_t Cols, unsigned NumW = 3>
std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols>& m) {
for (unsigned row = 0; row < Rows; row++) {
os << "[ ";
for (unsigned col = 0; col < (Cols -1); col++) {
- os << m.at(row, col) << ", ";
+ os << std::setw(NumW) << m.at(row, col) << ", ";
}
- os << m.at(row, (Cols -1)) << " ]\n";
+ os << std::setw(NumW) << m.at(row, (Cols -1)) << " ]\n";
}
return os;