summaryrefslogtreecommitdiffstats
path: root/include/mmmatrix.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/mmmatrix.hpp')
-rw-r--r--include/mmmatrix.hpp101
1 files changed, 95 insertions, 6 deletions
diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp
index 7cb9359..d18f3f1 100644
--- a/include/mmmatrix.hpp
+++ b/include/mmmatrix.hpp
@@ -12,16 +12,27 @@
#pragma once
#include <iostream>
+#include <cstring>
namespace mm {
template<typename T, std::size_t Rows, std::size_t Cols>
class basic_matrix;
+
+ template<typename T, std::size_t Rows, std::size_t Cols>
+ class matrix;
+
+ template<typename T, std::size_t N>
+ class square_matrix;
+
+ // template<typename T, std::size_t N>
+ // class diag_matrix;
+
template<typename T, std::size_t Rows>
- using row_vec = basic_matrix<T, Rows, 1>;
+ class row_vec;
template<typename T, std::size_t Cols>
- using col_vec = basic_matrix<T, 1, Cols>;
+ class col_vec;
}
template<typename T, std::size_t Rows, std::size_t Cols>
@@ -38,22 +49,45 @@ public:
template<std::size_t ORows, std::size_t OCols>
basic_matrix(const basic_matrix<T, ORows, OCols>& other);
+ basic_matrix(const T (& values)[Rows][Cols]);
+ basic_matrix(T (&& values)[Rows][Cols]);
+
// access data
T& at(std::size_t row, std::size_t col);
+ auto&& operator[](std::size_t index);
void swap_rows(std::size_t x, std::size_t y);
void swap_cols(std::size_t x, std::size_t y);
// mathematical operations
basic_matrix<T, Cols, Rows> transposed();
- inline basic_matrix<T, Cols, Rows> t() { return transposed(); }
+ inline basic_matrix<T, Cols, Rows> trd() { return transposed(); }
// bool is_invertible();
// bool invert();
// basic_matrix<T, Rows, Cols> inverse();
- inline constexpr bool is_square() {
- return (Rows == Cols);
+
+ /// downcast to square matrix
+ inline constexpr bool is_square() { return (Rows == Cols); }
+ inline constexpr square_matrix<T, Rows> to_square() {
+ static_assert(is_square());
+ return static_cast<square_matrix<T, Rows>>(*this);
+ }
+
+
+ /// downcast to row_vector
+ inline constexpr bool is_row_vec() { return (Cols == 1); }
+ inline constexpr row_vec<T, Rows> to_row_vec() {
+ static_assert(is_row_vec());
+ return static_cast<row_vec<T, Rows>>(*this);
+ }
+
+ /// downcast to col_vector
+ inline constexpr bool is_col_vec() { return (Rows == 1); }
+ inline constexpr col_vec<T, Cols> to_col_vec() {
+ static_assert(is_col_vec());
+ return static_cast<col_vec<T, Cols>>(*this);
}
private:
@@ -89,6 +123,16 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix(const mm::basic_matrix<T, ORows, O
data[row][col] = other.data[row][col];
}
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(const T (& values)[Rows][Cols]) {
+ std::memcpy(&data, &values, sizeof(data));
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(T (&& values)[Rows][Cols]) {
+ data = values;
+}
+
/* member functions */
@@ -101,6 +145,16 @@ T& mm::basic_matrix<T, Rows, Cols>::at(std::size_t row, std::size_t col) {
}
template<typename T, std::size_t Rows, std::size_t Cols>
+auto&& mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) {
+ if constexpr (is_row_vec())
+ return data[0][index];
+ else if constexpr (is_col_vec())
+ return data[index][0];
+
+ return row_vec<T, Rows>(std::move(data[index]));
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
void mm::basic_matrix<T, Rows, Cols>::swap_rows(std::size_t x, std::size_t y) {
if (x == y)
return;
@@ -124,7 +178,7 @@ mm::basic_matrix<T, N, M> mm::basic_matrix<T, M, N>::transposed() {
for (int row = 0; row < M; row++)
for (int col = 0; col < N; col++)
- result.at(row, col) = at(col, row);
+ result[row][col] = this[col][row];
return result;
}
@@ -203,3 +257,38 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols>
return os;
}
+
+
+/* square matrix specializaiton */
+
+template<typename T, std::size_t N>
+class mm::square_matrix : public mm::basic_matrix<T, N, N> {
+public:
+ /// in place transpose
+ void transpose();
+ inline void tr() { transpose(); }
+
+ /// in place inverse
+ void invert();
+};
+
+
+template<typename T, std::size_t N>
+void mm::square_matrix<T, N>::transpose() {
+ for (int row = 0; row < N; row++)
+ for (int col = 0; col < row; col++)
+ std::swap(this->at(row, col), this->at(col, row));
+}
+
+
+/* row vector specialization */
+template<typename T, std::size_t Rows>
+class mm::row_vec : public mm::basic_matrix<T, Rows, 1> {
+public:
+};
+
+/* column vector specialization */
+template<typename T, std::size_t Cols>
+class mm::col_vec : public mm::basic_matrix<T, 1, Cols> {
+public:
+};