From 929a36f2c6d5118afdc64aa5c32222f9162df2fb Mon Sep 17 00:00:00 2001
From: Michael Schneeberger <michael.schneeberger@fhnw.ch>
Date: Wed, 15 Feb 2023 09:29:36 +0100
Subject: make sorting optional for to_matrix_repr

---
 polymatrix/__init__.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index bc2f24b..23fe913 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -434,6 +434,7 @@ class MatrixRepresentations:
 def to_matrix_repr(
     expressions: Expression | tuple[Expression],
     variables: Expression,
+    sorted: bool = None,
 ) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]:
 
     if isinstance(expressions, Expression):
@@ -460,15 +461,19 @@ def to_matrix_repr(
 
         state, variable_index = get_variable_indices_from_variable(state, variables)
 
-        tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index)
+        if sorted:
+            tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index)
 
-        ordered_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
+            sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
+            
+        else:
+            sorted_variable_index = variable_index
 
-        assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables'
+        assert len(sorted_variable_index) == len(set(sorted_variable_index)), f'{sorted_variable_index=} contains repeated variables'
         
-        variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
+        variable_index_map = {old: new for new, old in enumerate(sorted_variable_index)}
 
-        n_param = len(ordered_variable_index)
+        n_param = len(sorted_variable_index)
 
         def gen_underlying_matrices():
             for underlying in underlying_list:
@@ -513,7 +518,7 @@ def to_matrix_repr(
 
         def gen_auxillary_equations():
             for key, monomial_terms in state.auxillary_equations.items():
-                if key in ordered_variable_index:
+                if key in sorted_variable_index:
                     yield key, monomial_terms
 
         auxillary_equations = tuple(gen_auxillary_equations())
@@ -546,7 +551,7 @@ def to_matrix_repr(
         result = MatrixRepresentations(
             data=underlying_matrices,
             aux_data=auxillary_matrix_equations,
-            variable_mapping=ordered_variable_index,
+            variable_mapping=sorted_variable_index,
             state=state,
         )
 
-- 
cgit v1.2.1