#pragma once #include "debug.hpp" namespace mm::iter { template class vector_iterator; template class basic_iterator; template class diag_iterator; } template class mm::iter::vector_iterator { public: template friend class mm::iter::basic_iterator; template friend class mm::iter::diag_iterator; vector_iterator(Grid& _M, std::size_t pos, std::size_t i = 0) : M(_M), position(pos), index(i) {} #ifdef MM_IMPLICIT_CONVERSION_ITERATOR operator T&() { npdebug("Calling +") return *(*this); } #endif IterType operator++() { IterType it = cpy(); ++index; return it; } IterType operator--() { IterType it = cpy(); --index; return it; } IterType& operator++(int) { ++index; return ref(); } IterType& operator--(int) { --index; return ref(); } bool operator==(const IterType& other) const { return index == other.index; } bool operator!=(const IterType& other) const { return index != other.index; } bool ok() const { return index < size(); } virtual std::size_t size() const = 0; virtual T& operator*() = 0; virtual T& operator[](std::size_t) = 0; virtual T& operator[](std::size_t) const = 0; IterType begin() { return IterType(M, position, 0); } virtual IterType end() = 0; protected: Grid& M; // grid mapping const std::size_t position; // fixed index, negative too for diagonal iterator std::size_t index; // variable index virtual IterType& ref() = 0; virtual IterType cpy() = 0; }; /* * Scalar product */ template typename std::remove_const::type operator*(const mm::iter::vector_iterator& v, const mm::iter::vector_iterator& w) { typename std::remove_const::type out(0); const std::size_t N = std::min(v.size(), w.size()); for(unsigned i = 0; i < N; ++i) out += v[i] * w[i]; return out; } template class mm::iter::basic_iterator : public mm::iter::vector_iterator, Grid> { bool direction; virtual mm::iter::basic_iterator& ref() override { return *this; } virtual mm::iter::basic_iterator cpy() override { return *this; } public: basic_iterator(Grid& A, std::size_t pos, std::size_t _index = 0, bool dir = true) : mm::iter::vector_iterator, Grid> (A, pos, _index), direction(dir) { //npdebug("Position: ", pos, ", Rows: ", Rows, " Cols: ", Cols, ", Direction: ", dir) if (direction) assert(pos < Rows); else assert(pos < Cols); } virtual std::size_t size() const { return (direction) ? Cols : Rows; } virtual T& operator*() override { return (direction) ? this->M.data[this->position * Cols + this->index] : this->M.data[this->index * Cols + this->position]; } virtual T& operator[](std::size_t i) override { return (direction) ? this->M.data[this->position * Cols + i] : this->M.data[i * Cols + this->position]; } virtual T& operator[](std::size_t i) const override { return (direction) ? this->M.data[this->position * Cols + i] : this->M.data[i * Cols + this->position]; } virtual mm::iter::basic_iterator end() { return mm::iter::basic_iterator(this->M, this->position, (direction) ? Cols : Rows); } }; template class mm::iter::diag_iterator : public mm::iter::vector_iterator, Grid> { bool sign; virtual mm::iter::diag_iterator& ref() override { return *this; } virtual mm::iter::diag_iterator cpy() override { return *this; } public: diag_iterator(Grid& A, signed long int pos, std::size_t _index = 0) : mm::iter::vector_iterator, Grid> (A, static_cast(labs(pos)), _index), sign(pos >= 0) { assert(this->position < N); } virtual std::size_t size() const { return N - this->position; } virtual T& operator*() override { return (sign) ? this->M.data[(this->index - this->position) * N + this->index] : this->M.data[this->index * N + (this->index + this->position)]; } virtual T& operator[](std::size_t i) override { return (sign) ? this->M.data[(i - this->position) * N + i] : this->M.data[i * N + (i + this->position)]; } virtual T& operator[](std::size_t i) const override { return (sign) ? this->M.data[(i - this->position) * N + i] : this->M.data[i * N + (i + this->position)]; } virtual mm::iter::diag_iterator end() { return mm::iter::diag_iterator(this->M, this->position, N); } };