summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/__init__.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index bc2f24b..23fe913 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -434,6 +434,7 @@ class MatrixRepresentations:
def to_matrix_repr(
expressions: Expression | tuple[Expression],
variables: Expression,
+ sorted: bool = None,
) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]:
if isinstance(expressions, Expression):
@@ -460,15 +461,19 @@ def to_matrix_repr(
state, variable_index = get_variable_indices_from_variable(state, variables)
- tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index)
+ if sorted:
+ tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index)
- ordered_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
+ sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
+
+ else:
+ sorted_variable_index = variable_index
- assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables'
+ assert len(sorted_variable_index) == len(set(sorted_variable_index)), f'{sorted_variable_index=} contains repeated variables'
- variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
+ variable_index_map = {old: new for new, old in enumerate(sorted_variable_index)}
- n_param = len(ordered_variable_index)
+ n_param = len(sorted_variable_index)
def gen_underlying_matrices():
for underlying in underlying_list:
@@ -513,7 +518,7 @@ def to_matrix_repr(
def gen_auxillary_equations():
for key, monomial_terms in state.auxillary_equations.items():
- if key in ordered_variable_index:
+ if key in sorted_variable_index:
yield key, monomial_terms
auxillary_equations = tuple(gen_auxillary_equations())
@@ -546,7 +551,7 @@ def to_matrix_repr(
result = MatrixRepresentations(
data=underlying_matrices,
aux_data=auxillary_matrix_equations,
- variable_mapping=ordered_variable_index,
+ variable_mapping=sorted_variable_index,
state=state,
)