blob: e5aef11b9d96d218aabdf9603a2e30e1a561ed27 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
|
import abc
import dataclasses
from polymatrix.polymatrix.mixins import PolyMatrixAsDictMixin
from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins import ExpressionStateMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin
class CacheExprMixin(ExpressionBaseMixin):
    """Caches the polynomial matrix of an underlying expression using the state.

    On the first `apply`, the `underlying` expression is evaluated, its
    polynomial matrix is stored in `state.cache` keyed by this expression, and
    the new state is returned. If `state.cache` already contains this
    expression, the cached matrix is returned without re-evaluating.
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """The expression whose evaluated result is cached (subclass-provided)."""
        ...

    # overwrites the abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionStateMixin,
    ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
        """Return the (possibly cached) polynomial matrix of `underlying`.

        Args:
            state: Expression state. Its `cache` is consulted first; on a miss
                a new state is produced whose `cache` also contains this
                expression's result.

        Returns:
            A pair of (state, polynomial matrix). On a cache hit the input
            state is returned unchanged together with the cached matrix.
        """
        # Cache hit: nothing to evaluate, the state stays as it is.
        if self in state.cache:
            return state, state.cache[self]

        state, underlying = self.underlying.apply(state)

        # Dict-backed matrices already expose their terms as a mapping;
        # otherwise collect the generated terms into a dict so the cached
        # matrix holds concrete terms.
        # NOTE(review): in the dict-backed case the `terms` mapping is shared
        # by reference with `underlying`, not copied.
        if isinstance(underlying, PolyMatrixAsDictMixin):
            cached_terms = underlying.terms
        else:
            cached_terms = dict(underlying.gen_terms())

        poly_matrix = init_poly_matrix(
            terms=cached_terms,
            shape=underlying.shape,
        )

        # Build a new state instead of mutating the existing one; assuming
        # `state.cache` is a dict, `|` creates a new merged dict.
        state = dataclasses.replace(
            state,
            cache=state.cache | {self: poly_matrix},
        )
        return state, poly_matrix
|