diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-02-15 09:29:36 +0100 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-02-15 09:29:36 +0100 |
commit | 929a36f2c6d5118afdc64aa5c32222f9162df2fb (patch) | |
tree | 68ec3250c4c4d4e8b7fc39668c17a99371f5b447 | |
parent | to_matrix_repr operator properly sorts the variables (diff) | |
download | polymatrix-929a36f2c6d5118afdc64aa5c32222f9162df2fb.tar.gz polymatrix-929a36f2c6d5118afdc64aa5c32222f9162df2fb.zip |
make sorting optional for to_matrix_repr
-rw-r--r-- | polymatrix/__init__.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index bc2f24b..23fe913 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -434,6 +434,7 @@ class MatrixRepresentations: def to_matrix_repr( expressions: Expression | tuple[Expression], variables: Expression, + sorted: bool = None, ) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]: if isinstance(expressions, Expression): @@ -460,15 +461,19 @@ def to_matrix_repr( state, variable_index = get_variable_indices_from_variable(state, variables) - tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index) + if sorted: + tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index) - ordered_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1])) + sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1])) + + else: + sorted_variable_index = variable_index - assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables' + assert len(sorted_variable_index) == len(set(sorted_variable_index)), f'{sorted_variable_index=} contains repeated variables' - variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)} + variable_index_map = {old: new for new, old in enumerate(sorted_variable_index)} - n_param = len(ordered_variable_index) + n_param = len(sorted_variable_index) def gen_underlying_matrices(): for underlying in underlying_list: @@ -513,7 +518,7 @@ def to_matrix_repr( def gen_auxillary_equations(): for key, monomial_terms in state.auxillary_equations.items(): - if key in ordered_variable_index: + if key in sorted_variable_index: yield key, monomial_terms auxillary_equations = tuple(gen_auxillary_equations()) @@ -546,7 +551,7 @@ def to_matrix_repr( result = MatrixRepresentations( data=underlying_matrices, aux_data=auxillary_matrix_equations, - variable_mapping=ordered_variable_index, + variable_mapping=sorted_variable_index, state=state, ) |