summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/utils/getvariableindices.py55
1 files changed, 24 insertions, 31 deletions
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index 9aa3619..c72bb02 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -3,43 +3,36 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
def get_variable_indices(state, variables):
- # print(f'{variables=}')
+ global_state = [state]
- if isinstance(variables, ExpressionBaseMixin):
- state, variables = variables.apply(state)
+ if not isinstance(variables, tuple):
+ variables = (variables,)
- assert variables.shape[1] == 1
+ def gen_indices():
+ for variable in variables:
+ if isinstance(variable, ExpressionBaseMixin):
+ global_state[0], variable_polynomial = variable.apply(global_state[0])
- def gen_indices():
- for row in range(variables.shape[0]):
- row_terms = variables.get_poly(row, 0)
+ assert variable_polynomial.shape[1] == 1
- assert len(row_terms) == 1, f'{row_terms} contains more than one term'
-
- for monomial in row_terms.keys():
- assert len(monomial) == 1, f'{monomial=} contains more than one variable'
- assert monomial[0][1] == 1, f'{monomial[0]=}'
+ for row in range(variable_polynomial.shape[0]):
+ row_terms = variable_polynomial.get_poly(row, 0)
- yield monomial[0][0]
+ assert len(row_terms) == 1, f'{row_terms} contains more than one term'
+
+ for monomial in row_terms.keys():
+ assert len(monomial) <= 1, f'{monomial=} contains more than one variable'
- return state, tuple(gen_indices())
+ if len(monomial) == 0:
+ continue
+
+ assert monomial[0][1] == 1, f'{monomial[0]=}'
+ yield monomial[0][0]
- else:
+ elif isinstance(variable, int):
+ yield variable
- # raise Exception('not supported anymore')
+ else:
+ state.offset_dict[variable][0]
- if not isinstance(variables, tuple):
- variables = (variables,)
-
- # assert all(isinstance(variable, type(variables[0])) for variable in variables)
-
- def gen_indices():
- for variable in variables:
-
- if isinstance(variable, int):
- yield variable
-
- else:
- yield state.offset_dict[variable][0]
-
- return state, tuple(gen_indices())
+ return state, tuple(gen_indices())