1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
import abc
import typing
from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials
from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.utils.tooperatorexception import to_operator_exception
class DerivativeExprMixin(ExpressionBaseMixin):
    """
    Differentiate a single-column polynomial expression w.r.t. a set of
    variables, producing its Jacobian:

        [[x**2]] -> [[2*x]]

    Entry ``(i, j)`` of the result is the derivative of row ``i`` of
    ``underlying`` with respect to the ``j``-th variable of ``variables``.

    introduce_derivatives: forwarded unchanged to ``get_derivative_monomials``;
    the original note says it is not used at the moment
    (NOTE(review): confirm against that helper).
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        # expression being differentiated; must evaluate to a single column
        ...

    @property
    @abc.abstractmethod
    def variables(self) -> ExpressionBaseMixin:
        # expression naming the variables to differentiate with respect to
        ...

    @property
    @abc.abstractmethod
    def introduce_derivatives(self) -> bool:
        # passed through to `get_derivative_monomials`
        ...

    @property
    @abc.abstractmethod
    def stack(self) -> tuple[FrameSummary]:
        # call-site frames used to build a readable operator exception
        ...

    # overwrites the abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionState,
    ) -> tuple[ExpressionState, PolyMatrix]:
        """
        Evaluate ``underlying`` and differentiate each of its rows.

        :param state: expression state; it is threaded through every helper
            call and the final version is returned.
        :returns: the updated state and a polynomial matrix of shape
            ``(underlying.shape[0], number of variables)``.
        :raises AssertionError: if ``underlying`` does not have exactly one
            column.
        """
        state, underlying = self.underlying.apply(state=state)
        state, diff_wrt_variables = get_variable_indices_from_variable(state, self.variables)

        # only column vectors are supported; each variable becomes one column
        if underlying.shape[1] != 1:
            raise AssertionError(to_operator_exception(
                message=f'{underlying.shape[1]=} is not 1',
                stack=self.stack,
            ))

        n_rows = underlying.shape[0]
        jacobian_terms = {}

        for row_idx in range(n_rows):
            row_poly = underlying.get_poly(row_idx, 0)

            # rows without terms contribute no entries to the result
            if row_poly is None:
                continue

            for col_idx, variable in enumerate(diff_wrt_variables):
                state, partial = get_derivative_monomials(
                    monomial_terms=row_poly,
                    diff_wrt_variable=variable,
                    state=state,
                    considered_variables=set(diff_wrt_variables),
                    introduce_derivatives=self.introduce_derivatives,
                )

                # store only non-empty derivatives; absent entries stay empty
                if len(partial) > 0:
                    jacobian_terms[row_idx, col_idx] = partial

        result = init_poly_matrix(
            terms=jacobian_terms,
            shape=(n_rows, len(diff_wrt_variables)),
        )
        return state, result
|