summaryrefslogtreecommitdiffstats
path: root/include/mm
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--include/mmmatrix.hpp60
1 files changed, 29 insertions, 31 deletions
diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp
index b973aba..8f60312 100644
--- a/include/mmmatrix.hpp
+++ b/include/mmmatrix.hpp
@@ -15,6 +15,7 @@
#include <cstring>
#include <cassert>
#include <initializer_list>
+#include <array>
namespace mm {
template<typename T, std::size_t Rows, std::size_t Cols>
@@ -44,6 +45,8 @@ public:
static constexpr std::size_t rows = Rows;
static constexpr std::size_t cols = Cols;
+ basic_matrix();
+
// from initializer_list
basic_matrix(std::initializer_list<std::initializer_list<T>> l);
@@ -98,42 +101,35 @@ public:
}
private:
- T data[Rows * Cols] = {};
+ std::array<T, Rows * Cols> data;
};
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix() {
+ std::fill(data.begin(), data.end(), 0);
+}
template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Rows, Cols>::basic_matrix(
std::initializer_list<std::initializer_list<T>> l
) {
assert(l.size() == Rows);
+ auto data_it = data.begin();
- auto row_it = l.begin();
- for (unsigned row = 0; (row < Rows) && (row_it != l.end()); row++) {
- assert((*row_it).size() == Cols);
-
- auto col_it = (*row_it).begin();
- for (unsigned col = 0; (col < Cols) && (col_it != (*row_it).end()); col++) {
- this->at(row, col) = *col_it;
- ++col_it;
- }
- ++row_it;
+ for (auto&& row : l) {
+ data_it = std::copy(row.begin(), row.end(), data_it);
}
}
template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Rows, Cols>::basic_matrix(
const mm::basic_matrix<T, Rows, Cols>& other
-) {
- std::memcpy(&data, &other.data, sizeof(data));
-}
+) : data(other.data) {}
template<typename T, std::size_t Rows, std::size_t Cols>
mm::basic_matrix<T, Rows, Cols>::basic_matrix(
mm::basic_matrix<T, Rows, Cols>&& other
-) {
- data = other.data;
-}
+) : data(std::forward<decltype(other.data)>(other.data)) {}
template<typename T, std::size_t Rows, std::size_t Cols>
template<std::size_t ORows, std::size_t OCols>
@@ -148,18 +144,19 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix(
"cannot copy a larger matrix into a smaller one"
);
- std::memcpy(&data, &other.data, sizeof(data));
+ std::fill(data.begin(), data.end(), 0);
+ for (unsigned row = 0; row < Rows; row++)
+ for (unsigned col = 0; col < Cols; col++)
+ this->at(row, col) = other.at(row, col);
}
template<typename T, std::size_t Rows, std::size_t Cols>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(const T (& values)[Rows][Cols]) {
- std::memcpy(&data, &values, sizeof(data));
-}
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(const T (& values)[Rows][Cols])
+ : data(values) {}
template<typename T, std::size_t Rows, std::size_t Cols>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(T (&& values)[Rows][Cols]) {
- data = values;
-}
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(T (&& values)[Rows][Cols])
+ : data(std::forward<decltype(values)>(values)) {}
/* member functions */
@@ -218,7 +215,7 @@ mm::basic_matrix<T, N, M> mm::basic_matrix<T, M, N>::transposed() const {
for (unsigned row = 0; row < M; row++)
for (unsigned col = 0; col < N; col++)
- result.at(row, col) = this->at(row, col);
+ result.at(col, row) = this->at(row, col);
return result;
}
@@ -234,7 +231,7 @@ mm::basic_matrix<T, Rows, Cols> operator+(
for (unsigned row = 0; row < Rows; row++)
for (unsigned col = 0; col < Cols; col++)
- result.at(row, col) = a.at(row, col) + a.at(row, col);
+ result.at(row, col) = a.at(row, col) + b.at(row, col);
return result;
}
@@ -260,17 +257,18 @@ mm::basic_matrix<T, Rows, Cols> operator*(
return m * scalar;
}
-template<typename T, std::size_t M, std::size_t P, std::size_t N>
+template<typename T, std::size_t M, std::size_t P1, std::size_t P2, std::size_t N>
mm::basic_matrix<T, M, N> operator*(
- const mm::basic_matrix<T, M, P>& a,
- const mm::basic_matrix<T, P, N>& b
+ const mm::basic_matrix<T, M, P1>& a,
+ const mm::basic_matrix<T, P2, N>& b
) {
+ static_assert(P1 == P2, "invalid matrix multiplication");
mm::basic_matrix<T, M, N> result;
// TODO: use a more efficient algorithm
for (unsigned row = 0; row < M; row++)
for (unsigned col = 0; col < N; col++)
- for (int k = 0; k < P; k++)
+ for (unsigned k = 0; k < P1; k++)
result.at(row, col) = a.at(row, k) * b.at(k, col);
return result;
@@ -282,7 +280,7 @@ mm::basic_matrix<T, Rows, Cols> operator-(
const mm::basic_matrix<T, Rows, Cols>& a,
const mm::basic_matrix<T, Rows, Cols>& b
) {
- return a + static_cast<T>(-1) * b;
+ return a + (static_cast<T>(-1) * b);
}
template<typename T, std::size_t Rows, std::size_t Cols>