summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/cacheexprmixin.py
blob: e5aef11b9d96d218aabdf9603a2e30e1a561ed27 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import abc
import dataclasses
from polymatrix.polymatrix.mixins import PolyMatrixAsDictMixin

from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins import ExpressionStateMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin


class CacheExprMixin(ExpressionBaseMixin):
    """
    Expression mixin that memoizes the polynomial matrix of its underlying
    expression in the expression state.

    The memo lives in the state's `cache` mapping and is keyed by the
    expression object itself, so concrete implementations must be hashable.
    On a cache miss the underlying expression is evaluated, its terms are
    materialized into a dictionary, a polynomial matrix is built from them
    with `init_poly_matrix`, and a new state containing the cache entry is
    returned.
    """

    # `abc.abstractclassmethod` is deprecated (and would wrap the getter in a
    # non-callable `classmethod` object); stacking `property` on
    # `abstractmethod` is the supported way to declare an abstract property.
    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """The expression whose polynomial matrix is cached."""

    # overwrites the abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionStateMixin,
    ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
        """
        Return the polynomial matrix of `underlying`, using the state's cache.

        :param state: Current expression state. Its `cache` mapping is
            consulted first and, on a miss, extended with a new entry.
        :returns: A pair `(state, poly_matrix)`. On a cache hit the input
            state is returned unchanged together with the cached matrix. On
            a miss, the returned state is a copy (via `dataclasses.replace`)
            of the state produced by evaluating `underlying`, whose cache
            additionally maps `self` to the newly built matrix.
        """

        # Cache hit: reuse the matrix stored by an earlier call.
        if self in state.cache:
            return state, state.cache[self]

        state, underlying = self.underlying.apply(state)

        # Materialize the terms so the cached matrix does not depend on a
        # lazily evaluated underlying implementation. Dict-backed matrices
        # already expose their terms via `terms` (presumably a dict); any
        # other implementation yields (key, value) pairs from `gen_terms()`.
        if isinstance(underlying, PolyMatrixAsDictMixin):
            cached_terms = underlying.terms
        else:
            cached_terms = dict(underlying.gen_terms())

        poly_matrix = init_poly_matrix(
            terms=cached_terms,
            shape=underlying.shape,
        )

        # Build a new state rather than mutating the existing cache mapping.
        state = dataclasses.replace(
            state,
            cache=state.cache | {self: poly_matrix},
        )

        return state, poly_matrix