1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
import abc
import collections
import itertools
import math
from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.utils.getvariableindices import get_variable_indices
from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
class SubstituteExprMixin(ExpressionBaseMixin):
    """Expression mixin that substitutes variables by scalar polynomials.

    Every occurrence of one of ``variables`` inside the polynomials of
    ``underlying`` is replaced by the matching entry of ``substitutions``
    (matched by position after the variables have been resolved to indices
    with ``get_variable_indices``).
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """Expression whose polynomial entries are rewritten."""
        ...

    @property
    @abc.abstractmethod
    def variables(self) -> tuple:
        """Variables to replace; passed as-is to ``get_variable_indices``."""
        ...

    @property
    @abc.abstractmethod
    def substitutions(self) -> tuple[ExpressionBaseMixin, ...]:
        """Replacement for each variable, in the same order.

        ``apply`` accepts 1x1 expressions as well as plain ``int`` / ``float``
        constants as elements.
        """
        ...

    # overwrites abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionState,
    ) -> tuple[ExpressionState, PolyMatrix]:
        """Evaluate ``underlying`` and substitute the selected variables.

        Args:
            state: Expression state; it is threaded through the evaluation of
                ``underlying``, the variable lookup and every substitution
                expression, in that order.

        Returns:
            The updated state and a poly matrix with the shape of the
            underlying matrix. Coefficients within ``1e-12`` of zero are
            dropped, and entries left without terms are omitted.

        Raises:
            Exception: If a substitution is neither an expression nor an
                ``int`` / ``float``.
            AssertionError: If a substitution expression does not evaluate to
                a 1x1 matrix.
        """
        state, underlying = self.underlying.apply(state=state)
        state, variable_indices = get_variable_indices(state, self.variables)

        # Turn every substitution into a polynomial dict (monomial -> value).
        substitutions = []
        for substitution_expr in self.substitutions:
            if isinstance(substitution_expr, ExpressionBaseMixin):
                state, substitution = substitution_expr.apply(state)
                assert substitution.shape == (1, 1), f'{substitution=}'
                polynomial = substitution.get_poly(0, 0)
                if polynomial is None:
                    # A missing entry means the substitution is identically
                    # zero; use the same constant form as the numeric branch
                    # below so the multiplication never receives `None`.
                    polynomial = {tuple(): 0.0}
            elif isinstance(substitution_expr, (int, float)):
                polynomial = {tuple(): substitution_expr}
            else:
                raise Exception(f'{substitution_expr=} not recognized')
            substitutions.append(polynomial)
        substitutions = tuple(substitutions)

        # Position of each substituted variable, built once instead of a
        # linear `in` + `.index()` scan per factor. `setdefault` keeps the
        # first occurrence, exactly like `.index()` would.
        substitution_index = {}
        for index, variable in enumerate(variable_indices):
            substitution_index.setdefault(variable, index)

        terms = {}

        for row in range(underlying.shape[0]):
            for col in range(underlying.shape[1]):
                polynomial = underlying.get_poly(row, col)
                if polynomial is None:
                    continue

                terms_row_col = collections.defaultdict(float)

                for monomial, value in polynomial.items():
                    # Start from the bare coefficient and multiply in one
                    # factor per variable of the monomial.
                    terms_monomial = {tuple(): value}

                    for variable, count in monomial:
                        index = substitution_index.get(variable)
                        if index is None:
                            # Untouched variable: keep it with its power.
                            factors = ({((variable, count),): 1.0},)
                        else:
                            # x**count -> substitution multiplied `count` times.
                            factors = (substitutions[index],) * count

                        for factor in factors:
                            product = {}
                            multiply_polynomial(terms_monomial, factor, product)
                            terms_monomial = product

                    # Different source monomials may yield the same monomial.
                    for new_monomial, new_value in terms_monomial.items():
                        terms_row_col[new_monomial] += new_value

                terms_row_col = {
                    key: val
                    for key, val in terms_row_col.items()
                    if not math.isclose(val, 0, abs_tol=1e-12)
                }

                if terms_row_col:
                    terms[row, col] = terms_row_col

        poly_matrix = init_poly_matrix(
            terms=terms,
            shape=underlying.shape,
        )

        return state, poly_matrix
|