summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/derivativeexprmixin.py
blob: 516aa38ac8ad970b6e842a0dc19e39e7b5faedeb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import abc
import collections
import typing
import dataclass_abc
from numpy import var
from polymatrix.expression.init.initderivativekey import init_derivative_key

from polymatrix.expression.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.polymatrixexprstate import PolyMatrixExprState


class DerivativeExprMixin(ExpressionBaseMixin):
    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        ...

    @property
    @abc.abstractmethod
    def variables(self) -> typing.Union[tuple, ExpressionBaseMixin]:
        ...

    @property
    @abc.abstractmethod
    def introduce_derivatives(self) -> bool:
        ...

    # overwrites abstract method of `ExpressionBaseMixin`
    @property
    def shape(self) -> tuple[int, int]:
        match self.variables:
            case ExpressionBaseMixin():
                n_cols = self.variables.shape[0]

            case _:
                n_cols = len(self.variables)

        return self.underlying.shape[0], n_cols

    # overwrites abstract method of `ExpressionBaseMixin`
    def apply(
        self, 
        state: PolyMatrixExprState,
    ) -> tuple[PolyMatrixExprState, PolyMatrix]:
        state = [state]

        state[0], underlying = self.underlying.apply(state=state[0])

        match self.variables:
            case ExpressionBaseMixin():
                assert self.variables.shape[1] == 1

                state[0], variables = self.variables.apply(state[0])

                def gen_indices():
                    for row in range(variables.shape[0]):
                        for monomial in variables.get_poly(row, 0).keys():
                            yield monomial[0]

                variable_indices = tuple(sorted(gen_indices()))
                # print(f'{variable_indices=}')

            case _:
                def gen_indices():
                    for variable in self.variable:
                        if variable in state[0].offset_dict:
                            yield state[0].offset_dict[variable][0]

                variable_indices = tuple(sorted(gen_indices()))

        terms = {}
        # aux_terms = []

        for var_idx, variable in enumerate(variable_indices):
            
            def get_derivative_terms(monomial_terms):

                terms_row_col = {}

                for monomial, value in monomial_terms.items():

                    # count powers for each variable
                    monomial_cnt = dict(collections.Counter(monomial))

                    if variable not in monomial_cnt:
                        continue

                    if self.introduce_derivatives:
                        variable_candidates = (variable,) + tuple(var for var in monomial_cnt.keys() if var not in variable_indices)
                    else:
                        variable_candidates = (variable,)
                    
                    for variable_candidate in variable_candidates:

                        def generate_monomial():
                            for current_variable, current_count in monomial_cnt.items():

                                if current_variable is variable_candidate:
                                    sel_counter = current_count - 1

                                else:
                                    sel_counter = current_count

                                for _ in range(sel_counter):
                                    yield current_variable

                            if variable_candidate is not variable:
                                key = init_derivative_key(
                                    variable=variable_candidate,
                                    with_respect_to=var_idx,
                                )
                                state[0] = state[0].register(key=key, n_param=1)

                                yield state[0].offset_dict[key][0]

                        col_monomial = tuple(generate_monomial())

                        if col_monomial not in terms_row_col:
                            terms_row_col[col_monomial] = 0

                        terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate]

                return terms_row_col

            for row in range(self.shape[0]):

                try:
                    underlying_terms = underlying.get_poly(row, 0)
                except KeyError:
                    continue

                derivative_terms = get_derivative_terms(underlying_terms)

                if 0 < len(derivative_terms):
                    terms[row, var_idx] = derivative_terms

            # if self.introduce_derivatives:
            #     for aux_monomial in underlying.aux_terms:

            #         derivative_terms = get_derivative_terms(aux_monomial)

            #         # todo: is this correct?
            #         if 1 < len(derivative_terms):
            #             aux_terms.append(derivative_terms)
        
        poly_matrix = init_poly_matrix(
            terms=terms,
            shape=self.shape,
            # aux_terms=underlying.aux_terms + tuple(aux_terms),
        )

        return state[0], poly_matrix