diff options
-rw-r--r-- | polymatrix/__init__.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index bc2f24b..23fe913 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -434,6 +434,7 @@ class MatrixRepresentations: def to_matrix_repr( expressions: Expression | tuple[Expression], variables: Expression, + sorted: bool = None, ) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]: if isinstance(expressions, Expression): @@ -460,15 +461,19 @@ def to_matrix_repr( state, variable_index = get_variable_indices_from_variable(state, variables) - tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index) + if sorted: + tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index) - ordered_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1])) + sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1])) + + else: + sorted_variable_index = variable_index - assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables' + assert len(sorted_variable_index) == len(set(sorted_variable_index)), f'{sorted_variable_index=} contains repeated variables' - variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)} + variable_index_map = {old: new for new, old in enumerate(sorted_variable_index)} - n_param = len(ordered_variable_index) + n_param = len(sorted_variable_index) def gen_underlying_matrices(): for underlying in underlying_list: @@ -513,7 +518,7 @@ def to_matrix_repr( def gen_auxillary_equations(): for key, monomial_terms in state.auxillary_equations.items(): - if key in ordered_variable_index: + if key in sorted_variable_index: yield key, monomial_terms auxillary_equations = tuple(gen_auxillary_equations()) @@ -546,7 +551,7 @@ def to_matrix_repr( result = MatrixRepresentations( data=underlying_matrices, aux_data=auxillary_matrix_equations, - variable_mapping=ordered_variable_index, + variable_mapping=sorted_variable_index, state=state, ) |