summaryrefslogtreecommitdiffstats
path: root/include/mmmatrix.hpp
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--include/mmmatrix.hpp97
1 files changed, 64 insertions, 33 deletions
diff --git a/include/mmmatrix.hpp b/include/mmmatrix.hpp
index 6a31610..5285429 100644
--- a/include/mmmatrix.hpp
+++ b/include/mmmatrix.hpp
@@ -22,20 +22,24 @@ namespace mm {
template<typename T, std::size_t Rows, std::size_t Cols>
class basic_matrix;
+ /* specialization of basic_matrx for Cols = 1 */
+ template<typename T, std::size_t Rows>
+ class row_vec;
+
+ /* specialization of basic_matrx for Rows = 1 */
+ template<typename T, std::size_t Cols>
+ class col_vec;
+
+ /* shorter name for basic_matrix */
template<typename T, std::size_t Rows, std::size_t Cols>
class matrix;
+ /* specialization of basic_matrix for Rows == Cols */
template<typename T, std::size_t N>
class square_matrix;
// template<typename T, std::size_t N>
// class diag_matrix;
-
- template<typename T, std::size_t Rows>
- class row_vec;
-
- template<typename T, std::size_t Cols>
- class col_vec;
}
template<typename T, std::size_t Rows, std::size_t Cols>
@@ -73,7 +77,7 @@ public:
// mathematical operations
virtual basic_matrix<T, Cols, Rows> transposed() const;
- inline basic_matrix<T, Cols, Rows> trd() const { return transposed(); }
+ inline basic_matrix<T, Cols, Rows> td() const { return transposed(); }
// bool is_invertible() const;
// basic_matrix<T, Rows, Cols> inverse() const;
@@ -102,8 +106,8 @@ public:
}
protected:
- template<typename Iterator>
- basic_matrix(Iterator begin, Iterator end);
+ template<typename ConstIterator>
+ basic_matrix(ConstIterator begin, ConstIterator end);
private:
std::array<T, Rows * Cols> data;
@@ -157,8 +161,10 @@ mm::basic_matrix<T, Rows, Cols>::basic_matrix(
/* protected construtor */
template<typename T, std::size_t Rows, std::size_t Cols>
-template<typename Iterator>
-mm::basic_matrix<T, Rows, Cols>::basic_matrix(Iterator begin, Iterator end) {
+template<typename ConstIterator>
+mm::basic_matrix<T, Rows, Cols>::basic_matrix(
+ ConstIterator begin, ConstIterator end
+) {
assert(static_cast<unsigned>(std::distance(begin, end)) >= ((Rows * Cols)));
std::copy(begin, end, data.begin());
}
@@ -188,8 +194,8 @@ auto mm::basic_matrix<T, Rows, Cols>::operator[](std::size_t index) {
return data.at(index);
} else {
return row_vec<T, Rows>(
- data.begin() + (index * Cols),
- data.begin() + ((index + 1) * Cols) + 1
+ data.cbegin() + (index * Cols),
+ data.cbegin() + ((index + 1) * Cols) + 1
);
}
}
@@ -300,26 +306,6 @@ std::ostream& operator<<(std::ostream& os, const mm::basic_matrix<T, Rows, Cols>
}
-/* square matrix specializaiton */
-
-template<typename T, std::size_t N>
-class mm::square_matrix : public mm::basic_matrix<T, N, N> {
-public:
- /// in place transpose
- void transpose();
- inline void tr() { transpose(); }
-
- /// in place inverse
- void invert();
-};
-
-
-template<typename T, std::size_t N>
-void mm::square_matrix<T, N>::transpose() {
- for (unsigned row = 0; row < N; row++)
- for (unsigned col = 0; col < row; col++)
- std::swap(this->at(row, col), this->at(col, row));
-}
/* row vector specialization */
@@ -336,8 +322,53 @@ public:
using mm::basic_matrix<T, 1, Cols>::basic_matrix;
};
+/* general specialization (alias) */
template<typename T, std::size_t Rows, std::size_t Cols>
class mm::matrix : public mm::basic_matrix<T, Rows, Cols> {
public:
using mm::basic_matrix<T, Rows, Cols>::basic_matrix;
};
+
+/* square matrix specializaiton */
+template<typename T, std::size_t N>
+class mm::square_matrix : public mm::basic_matrix<T, N, N> {
+public:
+ using mm::basic_matrix<T, N, N>::basic_matrix;
+
+ /// in place transpose
+ void transpose();
+ inline void t() { transpose(); }
+
+ T trace();
+ inline T tr() { return trace(); }
+
+ /// in place inverse
+ void invert();
+
+ // get the identity of size N
+ static inline constexpr square_matrix<T, N> identity() {
+ square_matrix<T, N> i;
+ for (unsigned row = 0; row < N; row++)
+ for (unsigned col = 0; col < N; col++)
+ i.at(row, col) = (row == col) ? 1 : 0;
+
+ return i;
+ }
+};
+
+
+template<typename T, std::size_t N>
+void mm::square_matrix<T, N>::transpose() {
+ for (unsigned row = 0; row < N; row++)
+ for (unsigned col = 0; col < row; col++)
+ std::swap(this->at(row, col), this->at(col, row));
+}
+
+template<typename T, std::size_t N>
+T mm::square_matrix<T, N>::trace() {
+ T sum = 0;
+ for (unsigned i = 0; i < N; i++)
+ sum += this->at(i, i);
+
+ return sum;
+}