summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
blob: ff162755cde6923f6d4d6c69a1421f5ffeb4857a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import abc
import dataclasses

from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins import ExpressionStateMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices


class QuadraticMonomialsExprMixin(ExpressionBaseMixin):
    """
    Maps a polynomial matrix

        underlying = [
            [x*y    ],
            [x + x^2],
        ]

    into a vector of monomials

        output = [1, x, y]

    in variable

        variables = [x, y].

    Every monomial of every entry of `underlying` is first restricted to the
    factors whose variable belongs to `variables`. The restricted monomial is
    then split into a `left` and a `right` part by `split_monomial_indices`.
    The distinct parts, sorted first by length and then lexicographically,
    become the entries of the resulting column vector, each with
    coefficient 1.0.

    NOTE(review): presumably `left * right` reproduces the restricted
    monomial, so the output can serve as a monomial basis `Z` with
    `entry = Z^T Q Z`. TODO confirm against `split_monomial_indices`.
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """Expression whose monomials are collected."""
        ...

    @property
    @abc.abstractmethod
    def variables(self) -> ExpressionBaseMixin:
        """Expression selecting the variables the monomials are restricted to."""
        ...

    # overwrites the abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionStateMixin,
    ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
        """
        Evaluate the expression into a column vector of monomials.

        :param state: expression state, threaded through the evaluation of
            `underlying` and `variables`.
        :returns: the updated state (with the result added to its cache
            under `self`) and a `(number of monomials, 1)` polynomial matrix
            whose i-th entry is the i-th collected monomial with coefficient 1.0.
        """

        state, underlying = self.underlying.apply(state=state)
        state, variable_indices = get_variable_indices_from_variable(state, self.variables)

        # Build the lookup set once: the membership test below runs for every
        # (variable, power) pair of every monomial of every matrix entry.
        selected_variables = set(variable_indices)

        def gen_sos_monomials():
            for row in range(underlying.shape[0]):
                for col in range(underlying.shape[1]):

                    polynomial = underlying.get_poly(row, col)
                    # skip entries for which no polynomial is stored
                    if polynomial is None:
                        continue

                    for monomial in polynomial.keys():
                        # keep only the factors in the requested variables
                        x_monomial = tuple(
                            (var_idx, count)
                            for var_idx, count in monomial
                            if var_idx in selected_variables
                        )

                        left, right = split_monomial_indices(x_monomial)

                        yield left
                        yield right

        # deduplicate, then order by length (number of entries in the
        # monomial tuple) and lexicographically for a deterministic result
        sos_monomials = tuple(sorted(set(gen_sos_monomials()), key=lambda m: (len(m), m)))

        terms = {
            (index, 0): {monomial: 1.0}
            for index, monomial in enumerate(sos_monomials)
        }

        poly_matrix = init_poly_matrix(
            terms=terms,
            shape=(len(sos_monomials), 1),
        )

        state = dataclasses.replace(
            state,
            cache=state.cache | {self: poly_matrix},
        )

        return state, poly_matrix