1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
|
import abc
import collections
import typing
import dataclass_abc
from numpy import var
from polymatrix.expression.init.initderivativekey import init_derivative_key
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.polymatrixexprstate import PolyMatrixExprState
class DerivativeExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
def underlying(self) -> ExpressionBaseMixin:
...
@property
@abc.abstractmethod
def variables(self) -> typing.Union[tuple, ExpressionBaseMixin]:
...
@property
@abc.abstractmethod
def introduce_derivatives(self) -> bool:
...
# overwrites abstract method of `ExpressionBaseMixin`
@property
def shape(self) -> tuple[int, int]:
match self.variables:
case ExpressionBaseMixin():
n_cols = self.variables.shape[0]
case _:
n_cols = len(self.variables)
return self.underlying.shape[0], n_cols
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
self,
state: PolyMatrixExprState,
) -> tuple[PolyMatrixExprState, PolyMatrix]:
state = [state]
state[0], underlying = self.underlying.apply(state=state[0])
match self.variables:
case ExpressionBaseMixin():
assert self.variables.shape[1] == 1
state[0], variables = self.variables.apply(state[0])
def gen_indices():
for row in range(variables.shape[0]):
for monomial in variables.get_poly(row, 0).keys():
yield monomial[0]
variable_indices = tuple(sorted(gen_indices()))
# print(f'{variable_indices=}')
case _:
def gen_indices():
for variable in self.variable:
if variable in state[0].offset_dict:
yield state[0].offset_dict[variable][0]
variable_indices = tuple(sorted(gen_indices()))
terms = {}
# aux_terms = []
for var_idx, variable in enumerate(variable_indices):
def get_derivative_terms(monomial_terms):
terms_row_col = {}
for monomial, value in monomial_terms.items():
# count powers for each variable
monomial_cnt = dict(collections.Counter(monomial))
if variable not in monomial_cnt:
continue
if self.introduce_derivatives:
variable_candidates = (variable,) + tuple(var for var in monomial_cnt.keys() if var not in variable_indices)
else:
variable_candidates = (variable,)
for variable_candidate in variable_candidates:
def generate_monomial():
for current_variable, current_count in monomial_cnt.items():
if current_variable is variable_candidate:
sel_counter = current_count - 1
else:
sel_counter = current_count
for _ in range(sel_counter):
yield current_variable
if variable_candidate is not variable:
key = init_derivative_key(
variable=variable_candidate,
with_respect_to=var_idx,
)
state[0] = state[0].register(key=key, n_param=1)
yield state[0].offset_dict[key][0]
col_monomial = tuple(generate_monomial())
if col_monomial not in terms_row_col:
terms_row_col[col_monomial] = 0
terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate]
return terms_row_col
for row in range(self.shape[0]):
try:
underlying_terms = underlying.get_poly(row, 0)
except KeyError:
continue
derivative_terms = get_derivative_terms(underlying_terms)
if 0 < len(derivative_terms):
terms[row, var_idx] = derivative_terms
# if self.introduce_derivatives:
# for aux_monomial in underlying.aux_terms:
# derivative_terms = get_derivative_terms(aux_monomial)
# # todo: is this correct?
# if 1 < len(derivative_terms):
# aux_terms.append(derivative_terms)
poly_matrix = init_poly_matrix(
terms=terms,
shape=self.shape,
# aux_terms=underlying.aux_terms + tuple(aux_terms),
)
return state[0], poly_matrix
|