summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/mmmatrix.hpp97
-rw-r--r--test/matrix_example.cpp27
2 files changed, 90 insertions, 34 deletions
diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp
index 6a31610..5285429 100644
--- a/include/mmmatrix.hpp
+++ b/include/mmmatrix.hpp
@@ -22,20 +22,24 @@ namespace mm {
template<typename T, std::size_t Rows, std::size_t Cols>
class basic_matrix;
+ /* specialization of basic_matrx for Cols = 1 */
+ template<typename T, std::size_t Rows>
+ class row_vec;
+
+ /* specialization of basic_matrx for Rows = 1 */
+ template<typename T, std::size_t Cols>
+ class col_vec;
+
+ /* shorter name for basic_matrix */
template<typename T, std::size_t Rows, std::size_t Cols>
class matrix;
+ /* specialization of basic_matrix for Rows == Cols */
template<typename T, std::size_t N>
class square_matrix;
// template<typename T, std::size_t N>
// class diag_matrix;
-
- template<typename T, std::size_t Rows>
- class row_vec;
-
- template<typename T, std::size_t Cols>
- class col_vec;
}
template<typename T, std::size_t Rows, std::size_t Cols>
@@ -73,7 +77,7 @@ public:
// mathematical operations
virtual basic_matrix<T, Cols, Rows> transposed() const;
- inline basic_matrix<T, Cols, Rows> trd() const { return transposed(); }
+ inline basic_matrix<T, Cols, Rows> td() const { return transposed(); }
// bool is_invertible() const;
// basic_matrix<T, Rows, Cols> inverse() const;
@@ -102,8 +106,8 @@ public:
}
protected:
- template<typename Iterator>
- basic_matrix(Iterator begin, Iterator end);
+ template<typename ConstIterator>
+ basic_matrix(ConstIterator begin, ConstIterator end);
private:
std::array<T, Rows * Cols> data;
@@ -157,8 +161,10 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix(
/* protected construtor */
template<typename T, std::size_t Rows, std::size_t Cols>
-template<typename Iterator>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(Iterator begin, Iterator end) {
+template<typename ConstIterator>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(
+ ConstIterator begin, ConstIterator end
+) {
assert(static_cast<unsigned>(std::distance(begin, end)) >= ((Rows * Cols)));
std::copy(begin, end, data.begin());
}
@@ -188,8 +194,8 @@ auto mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) {
return data.at(index);
} else {
return row_vec<T, Rows>(
- data.begin() + (index * Cols),
- data.begin() + ((index + 1) * Cols) + 1
+ data.cbegin() + (index * Cols),
+ data.cbegin() + ((index + 1) * Cols) + 1
);
}
}
@@ -300,26 +306,6 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols>
}
-/* square matrix specializaiton */
-
-template<typename T, std::size_t N>
-class mm::square_matrix : public mm::basic_matrix<T, N, N> {
-public:
- /// in place transpose
- void transpose();
- inline void tr() { transpose(); }
-
- /// in place inverse
- void invert();
-};
-
-
-template<typename T, std::size_t N>
-void mm::square_matrix<T, N>::transpose() {
- for (unsigned row = 0; row < N; row++)
- for (unsigned col = 0; col < row; col++)
- std::swap(this->at(row, col), this->at(col, row));
-}
/* row vector specialization */
@@ -336,8 +322,53 @@ public:
using mm::basic_matrix<T, 1, Cols>::basic_matrix;
};
+/* general specialization (alias) */
template<typename T, std::size_t Rows, std::size_t Cols>
class mm::matrix : public mm::basic_matrix<T, Rows, Cols> {
public:
using mm::basic_matrix<T, Rows, Cols>::basic_matrix;
};
+
+/* square matrix specializaiton */
+template<typename T, std::size_t N>
+class mm::square_matrix : public mm::basic_matrix<T, N, N> {
+public:
+ using mm::basic_matrix<T, N, N>::basic_matrix;
+
+ /// in place transpose
+ void transpose();
+ inline void t() { transpose(); }
+
+ T trace();
+ inline T tr() { return trace(); }
+
+ /// in place inverse
+ void invert();
+
+ // get the identity of size N
+ static inline constexpr square_matrix<T, N> identity() {
+ square_matrix<T, N> i;
+ for (unsigned row = 0; row < N; row++)
+ for (unsigned col = 0; col < N; col++)
+ i.at(row, col) = (row == col) ? 1 : 0;
+
+ return i;
+ }
+};
+
+
+template<typename T, std::size_t N>
+void mm::square_matrix<T, N>::transpose() {
+ for (unsigned row = 0; row < N; row++)
+ for (unsigned col = 0; col < row; col++)
+ std::swap(this->at(row, col), this->at(col, row));
+}
+
+template<typename T, std::size_t N>
+T mm::square_matrix<T, N>::trace() {
+ T sum = 0;
+ for (unsigned i = 0; i < N; i++)
+ sum += this->at(i, i);
+
+ return sum;
+}
diff --git a/test/matrix_example.cpp b/test/matrix_example.cpp
index 26aeede..469cbff 100644
--- a/test/matrix_example.cpp
+++ b/test/matrix_example.cpp
@@ -12,18 +12,43 @@ int main(int argc, char *argv[]) {
std::cout << "a = \n" << a;
std::cout << "b = \n" << b;
std::cout << "c = \n" << c;
+ std::cout << std::endl;
// access elements
+ std::cout << "Access elements" << std::endl;
std::cout << "a.at(2,0) = " << a.at(2, 0) << std::endl;
std::cout << "a[2][0] = " << a[2][0] << std::endl;;
+ std::cout << std::endl;
// basic operations
+ std::cout << "Basic operations" << std::endl;
std::cout << "a + b = \n" << a + b;
std::cout << "a - b = \n" << a - b;
std::cout << "a * c = \n" << a * c;
std::cout << "a * 2 = \n" << a * 2;
std::cout << "2 * a = \n" << 2 * a;
- std::cout << "tr(a) = \n" << a.trd();
+ std::cout << "a.td() = \n" << a.td(); // or a.trasposed();
+ std::cout << std::endl;
+
+ // special matrices
+ mm::square_matrix<std::complex<int>, 2> f {{{2, 3}, {1, 4}}, {{6, 1}, {-3, 4}}};
+
+ std::cout << "Square matrix" << std::endl;
+ std::cout << "f = \n" << f;
+
+ std::cout << "tr(f) = " << f.tr() /* or f.trace() */ << std::endl;
+
+ f.t();
+ std::cout << "after in place transpose f.t(), f = \n" << f;
+ std::cout << std::endl;
+
+
+ auto identity = mm::square_matrix<int, 3>::identity();
+
+ std::cout << "Identity matrix" << std::endl;
+ std::cout << "I = \n" << identity;
+ std::cout << std::endl;
+
return 0;
}