summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/utils/getvariableindices.py
blob: 743f6b98065e983bf666872d0fda5241702a6c7c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import itertools
import typing

from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin


# FIXME: typing
def get_variable_indices_from_variable(
    state: ExpressionState,
    variable: ExpressionBaseMixin | int | typing.Any,
) -> tuple[ExpressionState, tuple[int, ...] | None]:
    """Resolve *variable* to the variable indices it stands for.

    The kind of *variable* is checked in this order:

    * ``ExpressionBaseMixin`` -- the expression is applied to ``state``; the
      result must have exactly one column and every row must hold exactly
      one term. For each row whose term contains a variable, the first
      element of that variable's entry is collected. Rows whose only term
      is an empty monomial (no variable) contribute nothing, so the result
      may be shorter than the number of rows.
    * ``int`` -- used directly as a single index.
    * a key of ``state.offset_dict`` -- expanded to
      ``range(*state.offset_dict[variable])``.

    Returns:
        A pair ``(state, indices)``. ``state`` is the state returned by
        ``variable.apply`` in the expression case and the unchanged input
        otherwise. ``indices`` is a tuple of indices, or ``None`` when
        *variable* matches none of the cases above.

    Raises:
        AssertionError: in the expression case, if the applied result has
            more than one column, a row does not hold exactly one term, a
            term contains more than one variable, or the second element of
            the variable entry is not 1. These checks are plain ``assert``
            statements and are skipped when Python runs with ``-O``.
    """
    if isinstance(variable, ExpressionBaseMixin):
        state, variable_polynomial = variable.apply(state)

        # Only column vectors of variables are supported.
        assert (
            variable_polynomial.shape[1] == 1
        ), f"{variable_polynomial.shape=} does not have exactly one column"

        def gen_variables_indices():
            for row in range(variable_polynomial.shape[0]):
                row_terms = variable_polynomial.get_poly(row, 0)

                assert (
                    len(row_terms) == 1
                ), f"{row_terms} does not contain a single term"

                for monomial in row_terms.keys():
                    assert (
                        len(monomial) <= 1
                    ), f"{monomial=} contains more than one variable"

                    # An empty monomial carries no variable; skip the row.
                    if len(monomial) == 0:
                        continue

                    # NOTE(review): monomial entries look like
                    # (index, power) pairs; only power 1 is accepted --
                    # confirm against the polynomial data layout.
                    assert monomial[0][1] == 1, f"{monomial[0]=}"
                    yield monomial[0][0]

        variable_indices = tuple(gen_variables_indices())

    elif isinstance(variable, int):
        variable_indices = (variable,)

    elif variable in state.offset_dict:
        # The stored value is unpacked into range(), e.g. (start, stop).
        variable_indices = tuple(range(*state.offset_dict[variable]))

    else:
        # Unknown variable: signal it to the caller instead of raising.
        variable_indices = None

    return state, variable_indices


# not used, remove?
def get_variable_indices(state, variables):
    """Collect the variable indices of one or several variables.

    *variables* may be a single variable or a tuple of variables; anything
    that is not a tuple is treated as one variable. Each variable is
    resolved in order with ``get_variable_indices_from_variable``, the state
    being threaded from one call to the next, and the resulting index
    tuples are concatenated.

    Returns:
        A pair ``(state, indices)`` with the state after the last lookup
        and the concatenated tuple of indices (empty for an empty tuple of
        variables).

    Raises:
        TypeError: if a lookup yields ``None`` instead of a tuple, since
            the result is concatenated onto the indices gathered so far.
    """
    # Normalise a lone variable to a one-element tuple.
    variable_seq = variables if isinstance(variables, tuple) else (variables,)

    collected = tuple()
    for entry in variable_seq:
        state, entry_indices = get_variable_indices_from_variable(state, entry)
        collected = collected + entry_indices

    return state, collected