diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/utils/getvariableindices.py | 55 |
1 files changed, 24 insertions, 31 deletions
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index 9aa3619..c72bb02 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -3,43 +3,36 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin def get_variable_indices(state, variables): - # print(f'{variables=}') + global_state = [state] - if isinstance(variables, ExpressionBaseMixin): - state, variables = variables.apply(state) + if not isinstance(variables, tuple): + variables = (variables,) - assert variables.shape[1] == 1 + def gen_indices(): + for variable in variables: + if isinstance(variable, ExpressionBaseMixin): + global_state[0], variable_polynomial = variable.apply(global_state[0]) - def gen_indices(): - for row in range(variables.shape[0]): - row_terms = variables.get_poly(row, 0) + assert variable_polynomial.shape[1] == 1 - assert len(row_terms) == 1, f'{row_terms} contains more than one term' - - for monomial in row_terms.keys(): - assert len(monomial) == 1, f'{monomial=} contains more than one variable' - assert monomial[0][1] == 1, f'{monomial[0]=}' + for row in range(variable_polynomial.shape[0]): + row_terms = variable_polynomial.get_poly(row, 0) - yield monomial[0][0] + assert len(row_terms) == 1, f'{row_terms} contains more than one term' + + for monomial in row_terms.keys(): + assert len(monomial) <= 1, f'{monomial=} contains more than one variable' - return state, tuple(gen_indices()) + if len(monomial) == 0: + continue + + assert monomial[0][1] == 1, f'{monomial[0]=}' + yield monomial[0][0] - else: + elif isinstance(variable, int): + yield variable - # raise Exception('not supported anymore') + else: + state.offset_dict[variable][0] - if not isinstance(variables, tuple): - variables = (variables,) - - # assert all(isinstance(variable, type(variables[0])) for variable in variables) - - def gen_indices(): - for variable in variables: - - if isinstance(variable, int): - yield variable - - else: - yield state.offset_dict[variable][0] - - return state, tuple(gen_indices()) + return state, tuple(gen_indices()) |