summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/derivativeexprmixin.py
blob: 301166a240dd6a17acf342a187b02d1f6c44b771 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from __future__ import annotations

import abc
import typing

if typing.TYPE_CHECKING:
    from polymatrix.expressionstate.abc import ExpressionState

from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial
from polymatrix.expression.utils.getvariableindices import (
    get_variable_indices_from_variable,
)
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.utils.tooperatorexception import to_operator_exception


class DerivativeExprMixin(ExpressionBaseMixin):
    """
    Differentiate a column vector of polynomials w.r.t. a list of variables.

    The result is the Jacobian matrix: entry ``(row, col)`` holds the
    derivative of the ``row``-th polynomial of ``underlying`` w.r.t. the
    ``col``-th variable of ``variables``. For a single polynomial in ``x``:

    [[x**2]]  ->  [[2*x]]

    and for two polynomials w.r.t. ``(x, y)``:

    [[x**2], [x*y]]  ->  [[2*x, 0], [y, x]]

    Entries whose derivative is empty (zero) are not stored in the resulting
    polynomial matrix.

    introduce_derivatives: forwarded unchanged to `differentiate_polynomial`;
    per the original note it is not used at the moment.
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """Expression evaluating to the column vector to differentiate."""

    @property
    @abc.abstractmethod
    def variables(self) -> ExpressionBaseMixin:
        """
        Variables to differentiate with respect to; resolved to variable
        indices via `get_variable_indices_from_variable`.
        """

    @property
    @abc.abstractmethod
    def introduce_derivatives(self) -> bool:
        """Flag forwarded to `differentiate_polynomial`."""

    @property
    @abc.abstractmethod
    def stack(self) -> tuple[FrameSummary, ...]:
        """Stack frames passed to `to_operator_exception` when reporting errors."""

    # overwrites the abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionState,
    ) -> tuple[ExpressionState, PolyMatrix]:
        """
        Evaluate the Jacobian of `underlying` w.r.t. `variables`.

        :param state: expression state; threaded through the evaluation of
            `underlying`, the variable lookup and every call to
            `differentiate_polynomial`.
        :returns: the updated state and a polynomial matrix of shape
            ``(underlying.shape[0], len(variables))``.
        :raises AssertionError: if `underlying` does not have exactly one
            column.
        """
        state, underlying = self.underlying.apply(state=state)
        state, variables = get_variable_indices_from_variable(state, self.variables)

        # only column vectors are supported: the derivatives of each row are
        # laid out along the columns of the result
        if underlying.shape[1] != 1:
            raise AssertionError(
                to_operator_exception(
                    message=f"{underlying.shape[1]=} is not 1",
                    stack=self.stack,
                )
            )

        poly_matrix_data = {}

        for row in range(underlying.shape[0]):
            underlying_poly = underlying.get_poly(row, 0)

            # no polynomial stored at this entry -> nothing to differentiate
            if underlying_poly is None:
                continue

            # differentiate w.r.t. each variable; the derivative w.r.t. the
            # `col`-th variable goes into column `col`
            for col, variable in enumerate(variables):
                state, diff_polynomial = differentiate_polynomial(
                    polynomial=underlying_poly,
                    diff_wrt_variable=variable,
                    state=state,
                    # NOTE(review): a fresh set is built per call instead of
                    # being hoisted out of the loops; hoisting is only safe if
                    # `differentiate_polynomial` never mutates it — confirm.
                    considered_variables=set(variables),
                    introduce_derivatives=self.introduce_derivatives,
                )

                # skip empty (zero) derivatives so the result stays sparse
                if diff_polynomial:
                    poly_matrix_data[row, col] = diff_polynomial

        poly_matrix = init_poly_matrix(
            data=poly_matrix_data,
            shape=(underlying.shape[0], len(variables)),
        )

        return state, poly_matrix