summaryrefslogtreecommitdiffstats
path: root/include/mm/mmiterator.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/mm/mmiterator.hpp')
-rw-r--r--include/mm/mmiterator.hpp199
1 files changed, 199 insertions, 0 deletions
diff --git a/include/mm/mmiterator.hpp b/include/mm/mmiterator.hpp
new file mode 100644
index 0000000..c67b92b
--- /dev/null
+++ b/include/mm/mmiterator.hpp
@@ -0,0 +1,199 @@
+#pragma once
+
+namespace mm::iter {
+
+ template<typename T, std::size_t Rows, std::size_t Cols, class IterType, class Grid>
+ class vector_iterator;
+
+ template<typename T, std::size_t Rows, std::size_t Cols, class Grid>
+ class basic_iterator;
+
+ template<typename T, std::size_t N, class Grid>
+ class diag_iterator;
+}
+
+template<typename T, std::size_t Rows, std::size_t Cols, class IterType, class Grid>
+class mm::iter::vector_iterator
+{
+public:
+
+ template<typename U, std::size_t R, std::size_t C, class G>
+ friend class mm::iter::basic_iterator;
+
+ template<typename U, std::size_t N, class G>
+ friend class mm::iter::diag_iterator;
+
+ vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0)
+ : M(_M), position(pos), index(i) {}
+
+ operator T&()
+ {
+ return *(*this);
+ }
+
+ IterType operator++()
+ {
+ IterType it = *this;
+ ++index;
+ return it;
+ }
+
+ IterType operator--()
+ {
+ IterType it = *this;
+ --index;
+ return it;
+ }
+
+ IterType& operator++(int)
+ {
+ ++index;
+ return *this;
+ }
+
+ IterType& operator--(int)
+ {
+ --index;
+ return *this;
+ }
+
+ bool operator==(const IterType& other) const
+ {
+ return index == other.index;
+ }
+
+ bool operator!=(const IterType& other) const
+ {
+ return index != other.index;
+ }
+
+ virtual bool ok() const = 0;
+
+ virtual T& operator*() = 0;
+ virtual T& operator[](std::size_t) = 0;
+
+ IterType begin()
+ {
+ return IterType(M, position, 0);
+ }
+
+ virtual IterType end() = 0;
+
+ /*
+ * Scalar product
+ */
+
+ template<std::size_t P>
+ T operator*(const mm::iter::vector_iterator<T, Rows, P, IterType, Grid>& v)
+ {
+ T out(0);
+
+ for (unsigned k(0); k < Rows; ++k)
+ out += (*this)[k] * v[k];
+
+ return out;
+ }
+
+ template<std::size_t P>
+ T operator*(const mm::iter::vector_iterator<T, P, Cols, IterType, Grid>& v)
+ {
+ T out(0);
+
+ for (unsigned k(0); k < Cols; ++k)
+ out += (*this)[k] * v[k];
+
+ return out;
+ }
+
+protected:
+
+ Grid& M; // grid mapping
+
+ const std::size_t position; // fixed index, negative too for diagonal iterator
+ std::size_t index; // variable index
+};
+
+template<typename T, std::size_t Rows, std::size_t Cols, class Grid>
+class mm::iter::basic_iterator : public mm::iter::vector_iterator<T, Rows, Cols, mm::iter::basic_iterator<T, Rows, Cols, Grid>, Grid>
+{
+ bool direction;
+
+public:
+
+ basic_iterator(Grid& A, std::size_t pos, std::size_t _index = 0, bool dir = true)
+ : mm::iter::vector_iterator<T, Rows, Cols, mm::iter::basic_iterator<T, Rows, Cols, Grid>, Grid>
+ (A, pos, _index), direction(dir)
+ {
+ if (direction)
+ assert(pos < Rows);
+ else
+ assert(pos < Cols);
+ }
+
+ virtual bool ok() const override
+ {
+ return (direction) ? this->index < Cols : this->index < Rows;
+ }
+
+ virtual T& operator*() override
+ {
+ return (direction) ?
+ this->M.data[this->position * Cols + this->index] :
+ this->M.data[this->index * Cols + this->position];
+
+ }
+
+ virtual T& operator[](std::size_t i) override
+ {
+ return (direction) ?
+ this->M.data[this->position * Cols + i] :
+ this->M.data[i * Cols + this->position];
+ }
+
+ virtual mm::iter::basic_iterator<T, Rows, Cols, Grid> end()
+ {
+ return mm::iter::basic_iterator<T, Rows, Cols, Grid>(this->M, this->position,
+ (direction) ? Cols : Rows);
+ }
+};
+
+template<typename T, std::size_t N, class Grid>
+class mm::iter::diag_iterator : public mm::iter::vector_iterator<T, N, N, mm::iter::diag_iterator<T, N, Grid>, Grid>
+{
+ bool sign;
+
+public:
+
+ diag_iterator(Grid& A, signed long pos, std::size_t _index = 0)
+ : mm::iter::vector_iterator<T, N, N, mm::iter::diag_iterator<T, N, Grid>, Grid>
+ (A, static_cast<std::size_t>(abs(pos)), _index), sign(pos >= 0)
+ {
+ assert(this->position < N);
+ }
+
+ virtual bool ok() const
+ {
+ return this->index < N;
+ }
+
+ virtual T& operator*() override
+ {
+ return (sign) ?
+ this->M.data[(this->index - this->position) * N + this->index] :
+ this->M.data[this->index * N + (this->index + this->position)];
+ }
+
+ virtual T& operator[](std::size_t i) override
+ {
+ return (sign) ?
+ this->M.data[(i - this->position) * N + i] :
+ this->M.data[i * N + (i + this->position)];
+ }
+
+ virtual mm::iter::diag_iterator<T, N, Grid> end()
+ {
+ return mm::iter::diag_iterator<T, N, Grid>(this->M, this->position, N);
+ }
+};
+
+