summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/derivativeexprmixin.py
blob: 87d5f5bff56ea16c22694cfbf7145bb028f6778c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import abc
import typing

from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials
from polymatrix.expression.utils.getvariableindices import get_variable_indices


class DerivativeExprMixin(ExpressionBaseMixin):
    """Expression mixin that differentiates a single-column expression.

    The result of `apply` has one row per row of the underlying expression
    and one column per variable resolved from `variables`; entry
    `(row, col)` holds the terms returned by `get_derivative_monomials`
    for that row and that variable.
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """Expression whose evaluated result is differentiated."""
        ...

    @property
    @abc.abstractmethod
    def variables(self) -> typing.Union[tuple, ExpressionBaseMixin]:
        """Variables to differentiate with respect to.

        Passed to `get_variable_indices` to obtain the variable indices.
        """
        ...

    @property
    @abc.abstractmethod
    def introduce_derivatives(self) -> bool:
        """Flag forwarded unchanged to `get_derivative_monomials`."""
        ...

    # overwrites abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionState,
    ) -> tuple[ExpressionState, PolyMatrix]:
        """Evaluate the underlying expression and differentiate each row.

        Args:
            state: Expression state; it is threaded through every helper
                call and the final version is returned.

        Returns:
            The updated state and a poly matrix of shape
            `(rows of underlying, number of variables)`.

        Raises:
            AssertionError: If the underlying result does not have exactly
                one column.
        """
        state, underlying = self.underlying.apply(state=state)
        state, wrt_indices = get_variable_indices(state, self.variables)

        assert underlying.shape[1] == 1, f'{underlying.shape=}'

        n_rows = underlying.shape[0]
        derivative_terms = {}

        for row_idx in range(n_rows):
            # A row without a stored polynomial raises KeyError; such rows
            # contribute no entries to the result.
            try:
                row_poly = underlying.get_poly(row_idx, 0)
            except KeyError:
                continue

            # One output column per variable we differentiate against.
            for col_idx, wrt_index in enumerate(wrt_indices):
                state, derived = get_derivative_monomials(
                    monomial_terms=row_poly,
                    diff_wrt_variable=wrt_index,
                    state=state,
                    considered_variables=set(),
                    introduce_derivatives=self.introduce_derivatives,
                )

                # Only non-empty derivatives are stored.
                if len(derived) > 0:
                    derivative_terms[row_idx, col_idx] = derived

        return state, init_poly_matrix(
            terms=derivative_terms,
            shape=(n_rows, len(wrt_indices)),
        )