summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/tosortedvariablesmixin.py
blob: a61a7e9f3cff7e6408081558f1518754f73b9a16 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import abc
from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable

from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState


# TODO(review): this mixin may be obsolete — confirm whether it is still used before deleting it.
class ToSortedVariablesExprMixin(ExpressionBaseMixin):
    """Expression mixin that lists the variables of ``underlying`` as a column vector sorted by name.

    Subclasses provide ``underlying``. ``apply`` resolves it to variable
    offsets via ``get_variable_indices_from_variable`` and emits one row per
    offset, ordered by the name ``state.get_name_from_offset`` reports for it.
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """Expression handed to ``get_variable_indices_from_variable`` to obtain the offsets."""

    # overwrites the abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionState,
    ) -> tuple[ExpressionState, PolyMatrix]:
        """Build a column vector with one entry per variable of ``underlying``, sorted by name.

        Row ``i`` holds the single-term polynomial ``{((offset_i, 1),): 1}``,
        where ``offset_i`` is the i-th offset after sorting by
        ``state.get_name_from_offset`` (presumably the variable raised to the
        first power with coefficient 1 — confirm against the PolyMatrix term
        format).

        Args:
            state: Expression state used to resolve offsets and their names.

        Returns:
            The state returned by ``get_variable_indices_from_variable`` and a
            poly matrix of shape ``(len(variable_indices), 1)``.
        """
        state, variable_indices = get_variable_indices_from_variable(state, self.underlying)

        # `sorted` is stable: offsets that share a name keep their original relative order.
        sorted_offsets = sorted(variable_indices, key=state.get_name_from_offset)

        vector_terms = {
            (row, 0): {((offset, 1),): 1}
            for row, offset in enumerate(sorted_offsets)
        }

        result = init_poly_matrix(
            terms=vector_terms,
            shape=(len(variable_indices), 1),
        )
        return state, result