diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/__init__.py | 10 | ||||
-rw-r--r-- | polymatrix/expression/mixins/tosortedvariablesmixin.py | 6 | ||||
-rw-r--r-- | polymatrix/expression/utils/getvariableindices.py | 6 | ||||
-rw-r--r-- | polymatrix/expressionstate/mixins/expressionstatemixin.py | 6 |
4 files changed, 22 insertions, 6 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index de3f9f3..bc2f24b 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -458,10 +458,14 @@ def to_matrix_repr( initial=(state, tuple()), )) - state, ordered_variable_index = get_variable_indices_from_variable(state, variables) + state, variable_index = get_variable_indices_from_variable(state, variables) - assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables' + tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index) + + ordered_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1])) + assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables' + variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)} n_param = len(ordered_variable_index) @@ -483,7 +487,7 @@ def to_matrix_repr( continue for monomial, value in underlying_terms.items(): - + def gen_new_monomial(): for var, count in monomial: try: diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py index 441634b..b4c8d19 100644 --- a/polymatrix/expression/mixins/tosortedvariablesmixin.py +++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py @@ -21,8 +21,12 @@ class ToSortedVariablesMixin(ExpressionBaseMixin): ) -> tuple[ExpressionState, PolyMatrix]: state, variable_indices = get_variable_indices_from_variable(state, self.underlying) + tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_indices) + + ordered_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1])) + def gen_sorted_vector(): - for row, index in enumerate(sorted(variable_indices)): + for row, index in enumerate(ordered_variable_index): yield (row, 0), {((index, 1),): 1} poly_matrix = init_poly_matrix( diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index 901c567..61bae2a 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -1,9 +1,9 @@ import itertools -import typing + from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -def get_variable_indices_from_variable(state, variable) -> typing.Optional[tuple[int]]: +def get_variable_indices_from_variable(state, variable) -> tuple[int] | None: if isinstance(variable, ExpressionBaseMixin): state, variable_polynomial = variable.apply(state) @@ -29,9 +29,11 @@ def get_variable_indices_from_variable(state, variable) -> typing.Optional[tuple variable_indices = tuple(gen_variables_indices()) elif isinstance(variable, int): + # raise Exception(f'{variable=}') variable_indices = (variable,) elif variable in state.offset_dict: + # raise Exception(f'{variable=}') variable_indices = (state.offset_dict[variable][0],) else: diff --git a/polymatrix/expressionstate/mixins/expressionstatemixin.py b/polymatrix/expressionstate/mixins/expressionstatemixin.py index 18dba05..6268d4b 100644 --- a/polymatrix/expressionstate/mixins/expressionstatemixin.py +++ b/polymatrix/expressionstate/mixins/expressionstatemixin.py @@ -36,6 +36,12 @@ class ExpressionStateMixin( def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]: ... + def get_name_from_offset(self, offset: int): + for variable, (start, end) in self.offset_dict.items(): + if start <= offset < end: + return f'{str(variable)}_{offset-start}' + + def get_key_from_offset(self, offset: int): for variable, (start, end) in self.offset_dict.items(): if start <= offset < end: |