summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/mm/mmmatrix.hpp120
-rw-r--r--include/mm/view.hpp134
-rw-r--r--test/matrix_example.cpp72
3 files changed, 233 insertions, 93 deletions
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp
index d653d72..0f8833e 100644
--- a/include/mm/mmmatrix.hpp
+++ b/include/mm/mmmatrix.hpp
@@ -11,13 +11,19 @@
*/
#pragma once
+#include "mm/debug.hpp"
+
#include <iostream>
+#include <iomanip>
#include <cassert>
#include <initializer_list>
#include <array>
#include <memory>
+/*
+ * Forward declarations
+ */
namespace mm {
using index = std::size_t;
@@ -29,46 +35,96 @@ namespace mm {
class matrix;
template<typename T, std::size_t N>
+ class vector;
+
+ template<typename T, std::size_t N>
class square_matrix;
template<typename T, std::size_t N>
class diagonal_matrix;
+
}
/*
- * Matrix class, no access methods
+ * Matrix Classes
*/
namespace mm {
template<typename T, std::size_t Rows, std::size_t Cols>
- class basic_matrix
+ struct basic_matrix
{
public:
using type = T;
+ static constexpr std::size_t rows = Rows;
+ static constexpr std::size_t cols = Cols;
+
+
template<typename U, std::size_t ORows, std::size_t OCols>
- friend class mm::matrix;
+ friend class mm::basic_matrix;
+
+ virtual ~basic_matrix() {};
// copy from another matrix
- template<std::size_t ORows, std::size_t OCols>
- matrix(const basic_matrix<T, ORows, OCols>& other);
+ // template<std::std::size_t ORows, std::size_t OCols>
+ // basic_matrix(const basic_matrix<T, ORows, OCols>& other) {
+ // static_assert(ORows <= Rows);
+ // static_assert(OCols <= Cols);
+
+ // for (index row = 0; row < Rows; row++)
+ // for (index col = 0; col < Cols; col++)
+ // at(row, col) = other.at(row, col);
+ // }
virtual T& at(index row, index col) = 0;
- virtual const T& at(index row, index col) const = 0;
+ virtual const T& at(index row, index col) const = 0;
+
+ // constexpr std::size_t rows() { return Rows; }
+ // constexpr std::size_t cols() { return Cols; }
+
+ protected:
+ basic_matrix() {
+ npdebug("default construtor");
+ }
+
+ basic_matrix(const basic_matrix<T, Rows, Cols>& other) {
+ npdebug("copy constructor");
+ }
+
+ basic_matrix(basic_matrix<T, Rows, Cols>&& other) {
+ npdebug("move constructor");
+ }
};
+
/* Specializations */
template<typename T, std::size_t Rows, std::size_t Cols>
- struct matrix : public basic_matrix<T, N>
+ struct matrix : public basic_matrix<T, Rows, Cols>
{
public:
+ // aggregate initialization
+ template<typename ...E,
+ typename std::enable_if<
+ std::is_convertible<E, T>::value
+ >::type...
+ >
+ matrix(E ...e) : m_data({{std::forward<E>(e)...}}) {}
+
+ matrix(const matrix<T, Rows, Cols>& o)
+ : basic_matrix<T, Rows, Cols>(o), m_data(o.m_data) {}
+
+ matrix(matrix<T, Rows, Cols>&& o)
+ : basic_matrix<T, Rows, Cols>(std::move(o)), m_data(std::move(o.m_data)) {}
+
+ virtual ~matrix() = default;
+
virtual T& at(index row, index col) override {
return m_data[row * Cols + col];
}
virtual const T& at(index row, index col) const override {
- return at(row, col);
+ return m_data[row * Cols + col];
}
private:
@@ -80,20 +136,8 @@ namespace mm {
struct vector : public matrix<T, 1, N> {};
template<typename T, std::size_t N>
- struct square_matrix : public basic_matrix<T, N>
- {
- public:
- virtual T& at(index row, index col) override {
- return m_data[row * N + col];
- }
-
- virtual const T& at(index row, index col) const override {
- return at(row, col);
- }
-
- private:
- std::array<T, N*N> m_data;
- };
+ struct square_matrix : public matrix<T, N, N>
+ {};
template<typename T, std::size_t N>
struct identity_matrix : public basic_matrix<T, N, N>
@@ -104,17 +148,17 @@ namespace mm {
}
private:
- T m_useless;
- T& at(index row, index col) { return m_useless; }
- }
+ // not allowed
+ T& at(index row, index col) { return static_cast<T>(0); }
+ };
template<typename T, std::size_t N>
struct diagonal_matrix : public basic_matrix<T, N, N>
{
public:
T& at(index row, index col) override {
- n_null_element = static_cast<T>(0);
- return (row != col) ? m_data[row] : n_null_element;
+ m_null_element = static_cast<T>(0);
+ return (row != col) ? m_data[row] : m_null_element;
}
const T& at(index row, index col) const override {
@@ -124,5 +168,27 @@ namespace mm {
private:
T m_null_element;
std::array<T, N> m_data;
+ };
+}
+
+/*
+ * Matrix Opertors
+ */
+
+namespace mm {
+}
+
+
+template<typename T, std::size_t Rows, std::size_t Cols, unsigned NumW = 3>
+std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols>& m) {
+ for (mm::index row = 0; row < Rows; row++) {
+ os << "[ ";
+ for (mm::index col = 0; col < (Cols -1); col++) {
+ os << std::setw(NumW) << m.at(row, col) << ", ";
+ }
+ os << std::setw(NumW) << m.at(row, (Cols -1)) << " ]\n";
}
+
+ return os;
}
+ \ No newline at end of file
diff --git a/include/mm/view.hpp b/include/mm/view.hpp
index 910c16a..e701a78 100644
--- a/include/mm/view.hpp
+++ b/include/mm/view.hpp
@@ -1,66 +1,112 @@
#pragma once
-#include <mmmatrix.hpp>
+#include "mm/mmmatrix.hpp"
+#include <variant>
+#include <tuple>
+#include <type_traits>
+#include <functional>
-namespace mm::alg {
- template <
- template<typename, std::size_t, std::size_t> typename Matrix,
- typename T, std::size_t Rows, std::size_t Cols
- >
- struct visitor
- {
- using type = T;
+namespace mm {
+
+ namespace algorithm {
+ // does nothing
+ struct visit
+ {
+ visit() = default;
- // copy constructible
- visitor(const visitor& other) = default;
+ template<typename Matrix>
+ void operator()(Matrix& m) {}
+ };
- T& operator()(const Matrix<T, Rows, Cols>& m, index row, index col) {
- return m.at(row, col);
- }
+ struct transpose : public visit
+ {
+ /// does not work with non-square matrices
+ template<typename Matrix>
+ void operator()(Matrix& m) {
+ static_assert(Matrix::rows == Matrix::cols);
+ // naiive impl
+ for (index r = 0; r < m.rows / 2; r++)
+ for (index c = 0; c < m.cols; c++)
+ if (c != r)
+ std::swap(m.at(r, c), m.at(c, r));
+ }
+ };
- const T& operator()(const Matrix<T, Rows, Cols>& m, index row, index col) {
- return operator()(m, row, col);
- }
- };
+ /// algorithm aliases
+ using tr = transpose;
+ }
+
+ /// namespace alias
+ namespace alg = algorithm;
- template <
- template<typename, std::size_t, std::size_t> typename Matrix,
- typename T, std::size_t Rows, std::size_t Cols
- >
- struct transpose : public visitor<Matrix, T, Rows, Cols>
+ template<typename Matrix>
+ struct clone
{
- T& operator()(const Matrix<T, Rows, Cols> m, index row, index col) {
- // assert(col < Rows)
- // assert(row < Cols)
- return m.at(col, row);
+ Matrix matrix;
+
+ explicit clone(Matrix&& m) : matrix(m) {}
+ explicit clone(const Matrix& m) : matrix(m) {}
+
+ operator Matrix() {
+ return std::move(matrix);
}
};
-}
-namespace mm {
- template <
- template<typename, std::size_t, std::size_t> typename Matrix,
- typename T, std::size_t Rows, std::size_t Cols
- >
- struct view
+ template<typename Matrix, typename ...Algs>
+ struct mutate
{
- Matrix<T, Rows, Cols>& m;
- // std::stack<std::unique_ptr<alg::visitor>> visitors;
- std::unique_ptr<alg::visitor> visitor;
+ Matrix& matrix;
+ std::tuple<Algs...> visitors;
+
+ explicit mutate(Matrix& m) : matrix(m) {}
+
+ template<typename ...OAlgs, typename Alg>
+ explicit mutate(Matrix& m, std::tuple<OAlgs...>&& t, Alg&& v)
+ : matrix(m)
+ {
+ /// append the new operator
+ visitors = std::tuple_cat(t, std::make_tuple(v));
+ }
+
+ ~mutate() {
+ visit();
+ }
- T& at(index row, index col) {
- return visitor(m, row, col);
+ void visit() {
+ std::apply([this](auto&&... v) {
+ (v(matrix),...);
+ }, visitors);
}
- view& operator|=(const alg::visitor& other) {
- // visitors.push(std::move(std::make_unique<alg::visitor>(other)));
- visitor = std::make_unique<alg::visistor>(other);
+ operator Matrix() {
+ return std::move(matrix);
}
};
- view operator|(const view& left, const alg::visitor& right) {
- return left |= right;
+ template<typename Matrix, typename Alg>
+ clone<Matrix> operator|(clone<Matrix>&& cl, Alg&& v) {
+ static_assert(std::is_convertible<Alg, alg::visit>::value);
+ /// apply alg operator
+ v(cl.matrix);
+ /// forward to next alg
+ return clone<Matrix>(std::move(cl));
+ }
+
+ template<typename Matrix, typename ...Algs, typename Alg>
+ mutate<Matrix, Algs..., Alg> operator|(mutate<Matrix, Algs...>&& mut, Alg&& v) {
+ static_assert(std::is_convertible<Alg, alg::visit>::value);
+ /// append alg to the visitors tuple
+ return mutate<Matrix, Algs..., Alg>(
+ mut.matrix,
+ std::move(mut.visitors),
+ v
+ );
+ }
+
+ template<typename Matrix, typename Alg>
+ mutate<Matrix, Alg> operator|(Matrix& m, Alg&& v) {
+ return mutate(m) | std::move(v);
}
}
diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp
index 96ba67d..af95a00 100644
--- a/test/matrix_example.cpp
+++ b/test/matrix_example.cpp
@@ -1,36 +1,64 @@
#include "mm/mmmatrix.hpp"
+#include "mm/view.hpp"
#include <iostream>
#include <complex>
-int main(int argc, char *argv[]) {
- std::cout << "MxN dimensional (int) matrices" << std::endl;
- mm::matrix<int, 3, 2> a {{1, 2}, {3, 4}, {5, 6}};
- mm::matrix<int, 3, 2> b {{4, 3}, {9, 1}, {2, 5}};
- mm::matrix<int, 2, 4> c {{1, 2, 3, 4}, {5, 6, 7, 8}};
- auto ct = c.t();
+// int main(int argc, char *argv[]) {
+int main() {
+ // std::cout << "MxN dimensional (int) matrices" << std::endl;
- std::cout << "a = \n" << a;
- std::cout << "b = \n" << b;
- std::cout << "c = \n" << c;
- std::cout << "c^t = \n" << ct;
- std::cout << std::endl;
+ mm::matrix<int, 2, 2> a { 1, 2, 3, 4 };
+
+ // mm::matrix<int, 3, 2> a {{1, 2}, {3, 4}, {5, 6}};
+ // mm::matrix<int, 3, 2> a {1, 2, 3, 4, 5, 6};
+
+ // mm::matrix<int, 3, 2> b {4, 3, 9, 1, 2, 5};
+ // mm::matrix<int, 2, 4> c {1, 2, 3, 4, 5, 6, 7, 8};
+
+ // std::cout << "a = \n" << a;
+ // std::cout << "b = \n" << b;
+ // std::cout << "c = \n" << c;
+ // std::cout << std::endl;
// access elements
- std::cout << "Access elements" << std::endl;
- std::cout << "a.at(2,0) = " << a.at(2, 0) << std::endl;
- std::cout << "a[2][0] = " << a[2][0] << std::endl;;
- std::cout << std::endl;
+ // std::cout << "Access elements" << std::endl;
+ // std::cout << "a.at(2,0) = " << a.at(2, 0) << std::endl;
+ // std::cout << "a[2][0] = " << a[2][0] << std::endl;;
+ // std::cout << std::endl;
// basic operations
- std::cout << "Basic operations" << std::endl;
- std::cout << "a + b = \n" << a + b;
- std::cout << "a - b = \n" << a - b;
- std::cout << "a * c = \n" << a * c;
- std::cout << "a * 2 = \n" << a * 2;
- std::cout << "2 * a = \n" << 2 * a;
- std::cout << "a.td() = \n" << a.t(); // or a.trasposed();
+ // std::cout << "Basic operations" << std::endl;
+ // std::cout << "a + b = \n" << a + b;
+ // std::cout << "a - b = \n" << a - b;
+ // std::cout << "a * c = \n" << a * c;
+ // std::cout << "a * 2 = \n" << a * 2;
+ // std::cout << "2 * a = \n" << 2 * a;
+ // std::cout << "a.td() = \n" << a.t(); // or a.trasposed();
+ // std::cout << std::endl;
+
+ std::cout << "a = \n" << a;
+
+ std::cout << "Cloning a" << std::endl;
+ decltype(a) e = mm::clone(a) | mm::alg::transpose();
+ std::cout << "e = \n" << e;
std::cout << std::endl;
+ std::cout << "Mutating a" << std::endl;
+ mm::mutate(a) | mm::alg::transpose();
+ std::cout << "a = \n" << a;
+ std::cout << std::endl;
+
+ a | mm::alg::tr();
+ std::cout << "a = \n" << a;
+
+ // std::cout << "Converting clone object" << std::endl;
+ // mm::matrix<int, 2, 2> g = e;
+ // std::cout << std::endl;
+
+ // std::cout << "Converting mutate object" << std::endl;
+ // mm::matrix<int, 2, 2> h = f;
+ // std::cout << std::endl;
+
return 0;
}