summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/mm/mmiterator.hpp6
-rw-r--r--include/mm/mmmatrix.hpp198
-rw-r--r--include/mm/mmvec.hpp21
-rw-r--r--test/matrix_example.cpp6
4 files changed, 171 insertions, 60 deletions
diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp
index ed406fe..4efe474 100644
--- a/include/mm/mmiterator.hpp
+++ b/include/mm/mmiterator.hpp
@@ -28,13 +28,13 @@ public:
vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0)
: M(_M), position(pos), index(i) {}
-#ifdef MM_IMPLICIT_CONVERSION_ITERATOR
+//#ifdef MM_IMPLICIT_CONVERSION_ITERATOR
operator T&()
{
- npdebug("Calling +")
+ //npdebug("Calling +")
return *(*this);
}
-#endif
+//#endif
IterType operator++()
{
diff --git a/include/mm/mmmatrix.hpp b/include/mm/mmmatrix.hpp
index 0b7a40b..facbb48 100644
--- a/include/mm/mmmatrix.hpp
+++ b/include/mm/mmmatrix.hpp
@@ -23,41 +23,30 @@
namespace mm {
+ /* basic grid structure */
+
template<typename T, std::size_t Rows, std::size_t Cols>
class basic_matrix;
- /* specialisations */
+ /* basic wrapper */
template<typename T, std::size_t Rows, std::size_t Cols>
class matrix; // simple matrix format
+
+ /* specialisations */
- /* specialization of basic_matrx for Cols = 1 */
- template<typename T, std::size_t Rows>
- class row_vec;
-
- /* specialization of basic_matrx for Rows = 1 */
- template<typename T, std::size_t Cols>
- class col_vec; // transposed version of row_vec
+ /* specialization of a matrix */
+ template<typename T, std::size_t N>
+ class vector; // by default, set a column vector
template<typename T, std::size_t N>
class square_matrix;
/* specialisation of a square_matrix for a sub-diagonal composed matrix */
- template<typename T, std::size_t N, std::size_t K = 0>
- class diag_matrix;
+ template<typename T, std::size_t N, signed long ... Diags>
+ class multi_diag_matrix;
}
-
-/*namespace mm {
-
- template<typename T, std::size_t N>
- using diag_iterator = vector_iterator<T, N, N, MM_DIAG_ITER, mm::basic_matrix<T, N, N>>;
-
- template<typename T, std::size_t N>
- using const_diag_iterator = vector_iterator<typename std::add_const<T>::type, N, N, MM_DIAG_ITER, typename std::add_const<mm::basic_matrix<T, N, N>>::type>;
-}*/
-
-
/*
* Matrix class, no access methods
*/
@@ -211,11 +200,6 @@ public:
other.M = nullptr;
}
- // copy from another matrix
- /*template<std::size_t ORows, std::size_t OCols>
- matrix(const matrix<T, ORows, OCols>& other)
- : M(std::make_shared<mm::basic_matrix<T, Rows, Cols>(*other.M)), transposed(other.transposed) {} */
-
matrix<T, Rows, Cols> operator=(const basic_matrix<T, Rows, Cols>& other) // deep copy
{
*M = *other.M;
@@ -232,12 +216,9 @@ public:
return *this;
}
- matrix<T, Rows, Cols> transpose() const
+ const matrix<T, Rows, Cols> transpose() const
{
- auto m = shallow_cpy();
- m.transposed = !transposed;
-
- return m;
+ return matrix<T, Rows, Cols>(M, !transposed);
}
inline matrix<T, Rows, Cols>& td()
@@ -277,28 +258,24 @@ public:
return static_cast<square_matrix<T, Rows>>(*this);
}
- /// downcast to row_vector
- static inline constexpr bool is_row_vec() { return (Cols == 1); }
- inline constexpr row_vec<T, Rows> to_row_vec() const {
- static_assert(is_row_vec());
- return static_cast<row_vec<T, Rows>>(*this);
- }
-
/// downcast to col_vector
- static inline constexpr bool is_col_vec() { return (Rows == 1); }
- inline constexpr col_vec<T, Cols> to_col_vec() const {
- static_assert(is_col_vec());
- return static_cast<col_vec<T, Cols>>(*this);
+ static inline constexpr bool is_vector() { return (Rows == 1 || Cols == 1); }
+ inline vector<T, Cols> to_vector() const {
+ if constexpr(Cols == 1)
+ return static_cast<vector<T, Rows>>(*this);
+ else if (Rows == 1)
+ return vector<T, Cols>(*this); // copy into column vector
+
}
/* Accessors */
- T& at(std::size_t row, std::size_t col)
+ virtual T& at(std::size_t row, std::size_t col)
{
return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col];
}
- const T& at(std::size_t row, std::size_t col) const
+ virtual const T& at(std::size_t row, std::size_t col) const
{
return (transposed) ? M->data[col * Cols + row] : M->data[row * Cols + col];
}
@@ -311,12 +288,12 @@ public:
return (transposed) ? Rows : Cols;
}
- mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index)
+ virtual mm::matrix<T, Rows, Cols>::vec_iterator operator[](std::size_t index)
{
return mm::matrix<T, Rows, Cols>::vec_iterator(*M, index, 0, !transposed);
}
- mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const
+ virtual mm::matrix<T, Rows, Cols>::const_vec_iterator operator[](std::size_t index) const
{
return mm::matrix<T, Rows, Cols>::const_vec_iterator(*M, index, 0, !transposed);
}
@@ -358,12 +335,7 @@ protected:
// shallow construction
matrix(std::shared_ptr<mm::basic_matrix<T, Rows, Cols>> grid, bool tr = false) : M(grid), transposed(tr) {}
-
- matrix<T, Rows, Cols> shallow_cpy() const
- {
- return matrix<T, Rows, Cols>(M, transposed);
- }
-
+
private:
bool transposed;
@@ -416,7 +388,7 @@ mm::matrix<T, M, N> operator*(
mm::matrix<T, M, N> result;
const mm::matrix<T, P2, N> bt = b.t(); // weak transposition
- npdebug("Calling *")
+ //npdebug("Calling *")
for (unsigned row = 0; row < M; row++)
for (unsigned col = 0; col < N; col++)
@@ -444,6 +416,33 @@ std::ostream& operator<<(std::ostream& os, const mm::matrix<T, Rows, Cols>& m) {
}
/*
+ * Vector, TODO better manage column and row
+ */
+
+template<typename T, std::size_t N>
+class mm::vector : public mm::matrix<T, N, 1>
+{
+public:
+
+ using mm::matrix<T, N, 1>::matrix;
+
+ vector(std::initializer_list<T> l)
+ : mm::matrix<T, N, 1>(l) {}
+};
+
+template<typename T>
+mm::vector<T, 3> operator^(const mm::vector<T, 3>& v, const mm::vector<T, 3>& w)
+{
+ mm::vector<T, 3> out;
+
+ out[0] = v[1] * w[2] - v[2] * w[2];
+ out[1] = v[2] * w[0] - v[0] * w[2];
+ out[2] = v[0] * w[1] - v[1] * w[0];
+
+ return out;
+}
+
+/*
* Square matrix
*/
@@ -458,15 +457,15 @@ public:
using const_diag_iterator = mm::iter::diag_iterator<typename std::add_const<T>::type, N, typename std::add_const<mm::basic_matrix<T, N, N>>::type>;
- T trace();
+ virtual T trace();
inline T tr() { return trace(); }
- mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0)
+ virtual mm::square_matrix<T, N>::diag_iterator diag_beg(int row = 0)
{
return diag_iterator(*(this->M), row, 0);
}
- mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const
+ virtual mm::square_matrix<T, N>::const_diag_iterator diag_end(int row = 0) const
{
return const_diag_iterator(*(this->M), row, N);
}
@@ -502,4 +501,91 @@ T mm::square_matrix<T, N>::trace()
return sum;
}
+// TODO, static assert, for all: Diags > -N, Diags < N
+// TODO, force Diags to be ordered
+template<typename T, std::size_t N, signed long ... Diags>
+class mm::multi_diag_matrix : public mm::square_matrix<T, N>
+{
+ T& shared_zero = 0;
+
+public:
+ using mm::square_matrix<T, N>::square_matrix;
+
+ // TODO, ordered case: dichotomy search O(log(M))
+ // M = parameter pack size
+ static inline bool constexpr is_in(std::size_t i, std::size_t j)
+ {
+ auto t = std::make_tuple(Diags...);
+
+ for(unsigned k(0); k < sizeof...(Diags); ++k)
+ if ((i - j) == std::get<k>(t))
+ return true;
+
+ return false;
+ }
+
+ virtual T& at(std::size_t row, std::size_t col) override
+ {
+ if (is_in(row, col))
+ return mm::square_matrix<T, N>::at(row, col);
+
+ shared_zero = 0;
+ return shared_zero;
+ }
+
+ virtual const T& at(std::size_t row, std::size_t col) const override
+ {
+ if (is_in(row, col))
+ return mm::square_matrix<T, N>::at(row, col);
+
+ return 0;
+ }
+
+
+
+ // TODO, implement limited iterators
+};
+
+/*template<typename T, std::size_t N, signed long ... Diags, std::size_t P, std::size_t M>
+void constexpr diag_mult(const mm::multi_diag_matrix<T, N, Diags...>& a,
+ const mm::matrix<T, P, M>& b, mm::matrix<T, M, N>& result)
+{
+ static_assert(N == P && N == M, "invalid diagonal multiplication");
+
+ auto d = a.diagonal(Diags);
+ if constexpr (Diags < 0) {
+ for (unsigned k = 0; k < M; ++k)
+ for (unsigned i = -Diags; i < N; ++i)
+ result.at(i + Diags, k) += d[i + Diags] * b.at(i, k);
+ } else {
+ for (unsigned k = 0; k < M; ++k)
+ for (unsigned i = Diags; i < N; ++i)
+ result.at(i, k) += d[i - Diags] * b.at(i - Diags, k);
+ }
+}
+
+template<typename T, std::size_t N, signed long ... Diags, std::size_t P, std::size_t M>
+mm::matrix<T, M, N> operator*(
+ const mm::multi_diag_matrix<T, N, Diags...>& a,
+ const mm::matrix<T, P, M>& b
+) {
+ static_assert(N == P && N == M, "invalid matrix multiplication");
+ assert(a.cols() == b.rows());
+
+ mm::matrix<T, M, N> result;
+ ((
+ auto d = a.diagonal(Diags);
+ if constexpr (Diags < 0) {
+ for (unsigned k = 0; k < M; ++k)
+ for (unsigned i = -Diags; i < N; ++i)
+ result.at(i + Diags, k) += d[i + Diags] * b.at(i, k);
+ } else {
+ for (unsigned k = 0; k < M; ++k)
+ for (unsigned i = Diags; i < N; ++i)
+ result.at(i, k) += d[i - Diags] * b.at(i - Diags, k);
+ }
+ ) ...);
+
+ return result;
+}*/
diff --git a/include/mm/mmvec.hpp b/include/mm/mmvec.hpp
index 1939388..da9040d 100644
--- a/include/mm/mmvec.hpp
+++ b/include/mm/mmvec.hpp
@@ -44,6 +44,27 @@ struct mm::basic_vec : public std::array<T, d> {
using type = T;
static constexpr std::size_t dimensions = d;
+ // convertions
+ static inline constexpr bool is_vec2() {
+ return d == 2;
+ }
+
+ static inline constexpr bool is_vec3() {
+ return d == 3;
+ }
+
+ operator mm::vec2<T>()
+ {
+ static_assert(is_vec2(), "Invalid cast to two dimensional vector");
+ return static_cast<mm::vec2<T>>(*this);
+ }
+
+ operator mm::vec3<T>()
+ {
+ static_assert(is_vec3(), "Invalid cast to three dimensional vector");
+ return static_cast<mm::vec3<T>>(*this);
+ }
+
// TODO: template away these
static constexpr T null_element = static_cast<T>(0);
static constexpr T unit_element = static_cast<T>(1);
diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp
index a835ee0..a3f0eac 100644
--- a/test/matrix_example.cpp
+++ b/test/matrix_example.cpp
@@ -32,7 +32,7 @@ int main(int argc, char *argv[]) {
std::cout << "a.td() = \n" << a.t(); // or a.trasposed();
std::cout << std::endl;
- // special matrices
+ // square matrix
mm::square_matrix<std::complex<int>, 2> f {{{2, 3}, {1, 4}}, {{6, 1}, {-3, 4}}};
std::cout << "Square matrix" << std::endl;
@@ -50,5 +50,9 @@ int main(int argc, char *argv[]) {
std::cout << "I = \n" << identity;
std::cout << std::endl;
+ // vector
+
+ //
+
return 0;
}