summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/repmatexprmixin.py
blob: 48748beb60766cffab6d155dfe74eeb3f2b91128 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import abc
import dataclassabc

from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins import ExpressionStateMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin

class RepMatExprMixin(ExpressionBaseMixin):
    """
    Expression that tiles the underlying polynomial matrix.

    The result has shape ``(n_row * r_row, n_col * r_col)``, where
    ``(n_row, n_col)`` is the shape of the evaluated underlying expression
    and ``(r_row, r_col)`` is ``repetition``. Entry ``(row, col)`` of the
    result is entry ``(row % n_row, col % n_col)`` of the underlying matrix.
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """Expression whose evaluated matrix is repeated."""
        ...

    @property
    @abc.abstractmethod
    def repetition(self) -> tuple[int, int]:
        """Number of repetitions along the (row, column) axes."""
        ...

    # overwrites the abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionStateMixin,
    ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
        """
        Evaluate ``underlying`` and wrap the result in a tiled matrix view.

        :param state: expression state passed on to ``self.underlying.apply``.
        :returns: the state returned by ``self.underlying.apply`` together
            with a polynomial matrix whose entries are looked up in the
            underlying matrix on demand (no entries are copied).
        """
        state, underlying = self.underlying.apply(state)

        @dataclassabc.dataclassabc(frozen=True)
        class RepMatPolyMatrix(PolyMatrixMixin):
            underlying: PolyMatrixMixin
            shape: tuple[int, int]

            def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
                # Read the shape from the instance field, not from the
                # enclosing `apply` scope, so the method depends only on the
                # object it is called on.
                n_row, n_col = self.underlying.shape

                # Map the index back into the underlying matrix; indices wrap
                # around every `n_row` rows and every `n_col` columns.
                return self.underlying.get_poly(row % n_row, col % n_col)

        return state, RepMatPolyMatrix(
            underlying=underlying,
            # element-wise product of the underlying shape and the repetition counts
            shape=tuple(s * r for s, r in zip(underlying.shape, self.repetition)),
        )