# path: root/polymatrix/expression/mixins/substituteexprmixin.py
# blob: 9741897b35266b9f69001a9b3a0ef9cdb0ebd113
import abc
import collections
import itertools
import math

from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.utils.getvariableindices import get_variable_indices
from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial


class SubstituteExprMixin(ExpressionBaseMixin):
    """Expression that replaces variables of an underlying expression.

    Every occurrence of a selected variable in the polynomials of
    `underlying` is replaced by the polynomial of the substitution found at
    the same position; all other variables are kept unchanged.
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """Expression whose polynomials are rewritten."""
        ...

    @property
    @abc.abstractmethod
    def variables(self) -> tuple:
        """Variables to replace (resolved via `get_variable_indices`)."""
        ...

    @property
    @abc.abstractmethod
    def substitutions(self) -> tuple[ExpressionBaseMixin, ...]:
        """Replacements, matched by position to the resolved variable indices.

        Each entry is either an expression evaluating to a 1x1 polynomial
        matrix, or a plain `int`/`float` constant.
        """
        ...

    # overwrites abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionState,
    ) -> tuple[ExpressionState, PolyMatrix]:
        """Evaluate the underlying expression and substitute the variables.

        Args:
            state: expression state; it is threaded through the evaluation of
                the underlying expression, the variable lookup and every
                substitution expression.

        Returns:
            The updated state and a polynomial matrix with the shape of the
            underlying matrix. Coefficients that are zero within an absolute
            tolerance of 1e-12 are dropped, and entries without remaining
            terms are omitted.

        Raises:
            AssertionError: if a substitution expression does not evaluate to
                a 1x1 polynomial matrix.
            TypeError: if a substitution is neither an expression nor an
                `int`/`float`.
        """
        state, underlying = self.underlying.apply(state=state)
        state, variable_indices = get_variable_indices(state, self.variables)

        # Evaluate every substitution exactly once, in order, so that the
        # state is threaded through in the same sequence as the declarations.
        substitutions = []

        for substitution_expr in self.substitutions:
            if isinstance(substitution_expr, ExpressionBaseMixin):
                state, substitution = substitution_expr.apply(state)

                # Explicit raise instead of `assert` so that the check is not
                # stripped when Python runs with -O.
                if substitution.shape != (1, 1):
                    raise AssertionError(f'{substitution=}')

                polynomial = substitution.get_poly(0, 0)

                # `get_poly` returns None for an entry without terms (see the
                # guard on the underlying matrix below); treat that as the
                # zero polynomial instead of passing None on.
                if polynomial is None:
                    polynomial = {}

            elif isinstance(substitution_expr, (int, float)):
                polynomial = {tuple(): substitution_expr}

            else:
                raise TypeError(f'{substitution_expr=} not recognized')

            substitutions.append(polynomial)

        # Position of each variable index; on duplicates the first occurrence
        # wins, which matches `variable_indices.index(variable)`.
        # NOTE(review): assumes the entries of `variable_indices` are hashable
        # (they are compared against the variables used in monomial keys).
        position_of = {}
        for position, variable in enumerate(variable_indices):
            position_of.setdefault(variable, position)

        def substitute_monomial(monomial, value):
            # Expands `value * monomial` into a polynomial, replacing every
            # substituted variable by its polynomial raised to `count`.
            expanded = {tuple(): value}

            for variable, count in monomial:
                position = position_of.get(variable)

                if position is None:
                    # variable is not substituted: keep it with its power
                    factors = ({((variable, count),): 1.0},)
                else:
                    # the power is realised by repeated multiplication
                    factors = (substitutions[position],) * count

                for factor in factors:
                    # multiplying by the zero polynomial cancels the monomial
                    if not factor:
                        return {}

                    product = {}
                    multiply_polynomial(expanded, factor, product)
                    expanded = product

            return expanded

        terms = {}

        for row in range(underlying.shape[0]):
            for col in range(underlying.shape[1]):

                polynomial = underlying.get_poly(row, col)
                if polynomial is None:
                    continue

                terms_row_col = collections.defaultdict(float)

                # Different source monomials may expand to the same monomial,
                # hence the coefficients are accumulated.
                for monomial, value in polynomial.items():
                    for new_monomial, new_value in substitute_monomial(monomial, value).items():
                        terms_row_col[new_monomial] += new_value

                # drop coefficients that cancelled out numerically
                terms_row_col = {
                    key: val
                    for key, val in terms_row_col.items()
                    if not math.isclose(val, 0, abs_tol=1e-12)
                }

                if terms_row_col:
                    terms[row, col] = terms_row_col

        poly_matrix = init_poly_matrix(
            terms=terms,
            shape=underlying.shape,
        )

        return state, poly_matrix