1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
import itertools
import typing
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
def get_variable_indices_from_variable(
    state: ExpressionState,
    variable: ExpressionBaseMixin | int | typing.Any,
) -> tuple[ExpressionState, tuple[int, ...] | None]:
    """Resolve a single variable specification to a tuple of variable indices.

    ``variable`` is handled in one of three ways, checked in this order:

    * ``ExpressionBaseMixin``: the expression is evaluated with
      ``variable.apply(state)``. The result must have exactly one column.
      Each row must consist of exactly one term whose monomial contains at
      most one entry. For a non-empty monomial, the entry's second element
      must be 1 (presumably the exponent), and its first element is
      collected as the index. A row whose monomial is empty (a constant) is
      skipped, so the result may have fewer entries than the expression has
      rows.
    * ``int``: used directly as a single variable index.
    * A key of ``state.offset_dict``: expanded to
      ``tuple(range(*state.offset_dict[variable]))``.

    Args:
        state: The expression state used to evaluate expressions and to
            look up variable offsets.
        variable: The variable specification described above.

    Returns:
        A pair ``(state, indices)``. ``state`` is the state returned by
        ``variable.apply`` for expression arguments, or the input state
        otherwise. ``indices`` is a tuple of variable indices, or ``None``
        if ``variable`` matches none of the forms above.

    Raises:
        AssertionError: If an expression argument does not have the
            structure described above (only while assertions are enabled;
            they are stripped under ``python -O``).
    """
    if isinstance(variable, ExpressionBaseMixin):
        state, variable_polynomial = variable.apply(state)
        assert (
            variable_polynomial.shape[1] == 1
        ), f"{variable_polynomial.shape=} must have exactly one column"

        def gen_variables_indices():
            for row in range(variable_polynomial.shape[0]):
                row_terms = variable_polynomial.get_poly(row, 0)
                assert (
                    len(row_terms) == 1
                ), f"{row_terms} does not contain a single term"
                for monomial in row_terms.keys():
                    assert (
                        len(monomial) <= 1
                    ), f"{monomial=} contains more than one variable"
                    # An empty (constant) monomial contributes no index.
                    if len(monomial) == 0:
                        continue
                    assert monomial[0][1] == 1, f"{monomial[0]=}"
                    yield monomial[0][0]

        variable_indices = tuple(gen_variables_indices())
    elif isinstance(variable, int):
        variable_indices = (variable,)
    elif variable in state.offset_dict:
        # offset_dict values are unpacked as range() arguments
        # (presumably (start, stop) -- confirm against ExpressionState).
        variable_indices = tuple(range(*state.offset_dict[variable]))
    else:
        # Unrecognised specification: callers must check for None.
        variable_indices = None
    return state, variable_indices
# NOTE(review): flagged as possibly unused ("not used, remove?"); confirm
# there are no callers elsewhere before deleting.
def get_variable_indices(
    state: "ExpressionState",
    variables,
) -> "tuple[ExpressionState, tuple[int, ...]]":
    """Resolve one or more variables to a single flat tuple of indices.

    Each entry of ``variables`` is resolved with
    ``get_variable_indices_from_variable``, and the resulting index tuples
    are concatenated in order. The state returned by each call is passed to
    the next one, so the final state reflects every resolution.

    Args:
        state: The expression state to resolve against.
        variables: A single variable, or a tuple/list of variables. Each
            entry may be anything ``get_variable_indices_from_variable``
            accepts.

    Returns:
        A pair ``(state, indices)``. ``indices`` is the concatenation of all
        resolved index tuples, or ``()`` when ``variables`` is empty.

    Raises:
        TypeError: If some entry cannot be resolved to variable indices.
    """
    # Accept lists as well as tuples, so that a list of variables is
    # resolved element-wise instead of being passed on as one variable.
    if not isinstance(variables, (tuple, list)):
        variables = (variables,)

    indices: tuple[int, ...] = ()
    for variable in variables:
        state, new_indices = get_variable_indices_from_variable(state, variable)
        if new_indices is None:
            # The helper signals "unrecognised variable" with None;
            # concatenating it would fail with an opaque tuple/NoneType error.
            raise TypeError(
                f"{variable!r} could not be resolved to variable indices"
            )
        indices += new_indices
    return state, indices
|