summaryrefslogtreecommitdiffstats
path: root/include/mm
diff options
context:
space:
mode:
authorNao Pross <naopross@thearcway.org>2019-02-23 00:58:50 +0100
committerNao Pross <naopross@thearcway.org>2019-02-23 00:58:50 +0100
commit69bfc145590498104e9ed960c38c4e63b1c5c952 (patch)
tree2d741282e1b42b54dcde76b85f0e93fafd95a6b6 /include/mm
parentMerge branch 'master' into matrices (diff)
downloadlibmm-69bfc145590498104e9ed960c38c4e63b1c5c952.tar.gz
libmm-69bfc145590498104e9ed960c38c4e63b1c5c952.zip
Add +, -, * operators to mm::basic_matrix
Diffstat (limited to '')
-rw-r--r--include/mmmatrix.hpp160
1 files changed, 147 insertions, 13 deletions
diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp
index 7934d87..51233cb 100644
--- a/include/mmmatrix.hpp
+++ b/include/mmmatrix.hpp
@@ -3,12 +3,8 @@
*
* This library is not intended to be _performant_, it does not contain
* hand written SMID / SSE / AVX optimizations. It is instead an example
- * of highly abstracted code, where matrices can contain any data type.
- *
- * As a challenge, the matrix data structure has been built on a container
- * of static capacity. But if a dynamic base container is needed, the code
- * should be easily modifiable to add further abstraction, by templating
- * the container and possibly the allocator.
+ * of highly inefficient (but abstract!) code, where matrices can contain any
+ * data type.
*
* Naoki Pross <naopross@thearcway.org>
* 2018 ~ 2019
@@ -21,6 +17,12 @@
namespace mm {
template<typename T, std::size_t Rows, std::size_t Cols>
class basic_matrix;
+
+ template<typename T, std::size_t Rows>
+ using row_vec = basic_matrix<T, Rows, 1>;
+
+ template<typename T, std::size_t Cols>
+ using col_vec = basic_matrix<T, 1, Cols>;
}
template<typename T, std::size_t Rows, std::size_t Cols>
@@ -31,22 +33,50 @@ public:
static constexpr std::size_t rows = Rows;
static constexpr std::size_t cols = Cols;
- basic_matrix() {}
+ basic_matrix(const basic_matrix<T, Rows, Cols>& other);
+ basic_matrix(basic_matrix<T, Rows, Cols>&& other);
+
+ template<std::size_t ORows, std::size_t OCols>
+ basic_matrix(const basic_matrix<T, ORows, OCols>& other);
+
+ // access data
+ T& at(std::size_t row, std::size_t col);
+
+ void swap_rows(std::size_t x, std::size_t y);
+ void swap_cols(std::size_t x, std::size_t y);
+
+ // mathematical operations
+ basic_matrix<T, Cols, Rows> transposed();
+ inline basic_matrix<T, Cols, Rows> t() { return transposed(); }
- template<std::size_t Rows_, std::size_t Cols_>
- basic_matrix(const basic_matrix<T, Rows_, Cols_>& other);
+ // bool is_invertible();
+ // bool invert();
+ // basic_matrix<T, Rows, Cols> inverse();
- const T& at(std::size_t row, std::size_t col);
+ inline constexpr bool is_square() {
+ return (Rows == Cols);
+ }
private:
- std::array<T, (Rows * Cols)> data;
+ T data[Rows][Cols] = {};
};
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(const mm::basic_matrix<T, Rows, Cols>& other) {
+ for (int row = 0; row < Rows; row++)
+ for (int col = 0; col < Cols; col++)
+ data[row][col] = other.data[row][col];
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(mm::basic_matrix<T, Rows, Cols>&& other) {
+ data = other.data;
+}
template<typename T, std::size_t Rows, std::size_t Cols>
template<std::size_t ORows, std::size_t OCols>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(const basic_matrix<T, ORows, OCols>& other) {
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(const mm::basic_matrix<T, ORows, OCols>& other) {
static_assert((ORows <= Rows),
"cannot copy a taller matrix into a smaller one"
);
@@ -55,5 +85,109 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix(const basic_matrix<T, ORows, OCols
"cannot copy a larger matrix into a smaller one"
);
- std::copy(std::begin(other.data), std::end(other.data), data.begin());
+ for (int row = 0; row < Rows; row++)
+ for (int col = 0; col < Cols; col++)
+ data[row][col] = other.data[row][col];
+}
+
+
+/* member functions */
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+T& mm::basic_matrix<T, Rows, Cols>::at(std::size_t row, std::size_t col) {
+ static_assert(row < Rows, "out of row bound");
+ static_assert(col < Cols, "out of column bound");
+
+ return data[row][col];
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+void mm::basic_matrix<T, Rows, Cols>::swap_rows(std::size_t x, std::size_t y) {
+ if (x == y)
+ return;
+
+ for (int col = 0; col < Cols; col++)
+ std::swap(data[x][col], data[y][col]);
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+void mm::basic_matrix<T, Rows, Cols>::swap_cols(std::size_t x, std::size_t y) {
+ if (x == y)
+ return;
+
+ for (int row = 0; row < rows; row++)
+ std::swap(data[row][x], data[row][y]);
+}
+
+template<typename T, std::size_t M, std::size_t N>
+mm::basic_matrix<T, N, M> mm::basic_matrix<T, M, N>::transposed() {
+ mm::basic_matrix<T, N, M> result;
+
+ for (int row = 0; row < M; row++)
+ for (int col = 0; col < N; col++)
+ result.at(row, col) = at(col, row);
+
+ return result;
+}
+
+
+/* operator overloading */
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::basic_matrix<T, Rows, Cols> operator+(
+ const mm::basic_matrix<T, Rows, Cols>& a,
+ const mm::basic_matrix<T, Rows, Cols>& b
+) {
+ mm::basic_matrix<T, Rows, Cols> result;
+
+ for (int row = 0; row < Rows; row++)
+ for (int col = 0; col < Cols; col++)
+ result.at(row, col) = a.at(row, col) + a.at(row, col);
+
+ return result;
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::basic_matrix<T, Rows, Cols> operator*(
+ const mm::basic_matrix<T, Rows, Cols>& m,
+ const T& scalar
+) {
+ mm::basic_matrix<T, Rows, Cols> result;
+ for (int row = 0; row < Rows; row++)
+ for (int col = 0; col < Cols; col++)
+ result.at(row, col) = m.at(row, col) * scalar;
+
+ return result;
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::basic_matrix<T, Rows, Cols> operator*(
+ const T& scalar,
+ const mm::basic_matrix<T, Rows, Cols>& m
+) {
+ return m * scalar;
+}
+
+template<typename T, std::size_t M, std::size_t P, std::size_t N>
+mm::basic_matrix<T, M, N> operator*(
+ const mm::basic_matrix<T, M, P>& a,
+ const mm::basic_matrix<T, P, N>& b
+) {
+ mm::basic_matrix<T, M, N> result;
+
+ // TODO: use a more efficient algorithm
+ for (int row = 0; row < M; row++)
+ for (int col = 0; col < N; col++)
+ for (int k = 0; k < P; k++)
+ result.at(row, col) = a.at(row, k) * b.at(k, col);
+
+ return result;
+}
+
+
+template<typename T, std::size_t Rows, std::size_t Cols>
+mm::basic_matrix<T, Rows, Cols> operator-(
+ const mm::basic_matrix<T, Rows, Cols>& a,
+ const mm::basic_matrix<T, Rows, Cols>& b
+) {
+ return a + static_cast<T>(-1) * b;
}