From 67eaeee2e05924e0b73b0b1dce1cb88d55befb2c Mon Sep 17 00:00:00 2001
From: Michael Schneeberger <michael.schneeberger@fhnw.ch>
Date: Mon, 13 Feb 2023 16:57:56 +0100
Subject: to_matrix_repr operator properly sorts the variables

---
 polymatrix/__init__.py                                    | 10 +++++++---
 polymatrix/expression/mixins/tosortedvariablesmixin.py    |  6 +++++-
 polymatrix/expression/utils/getvariableindices.py         |  6 ++++--
 polymatrix/expressionstate/mixins/expressionstatemixin.py |  6 ++++++
 4 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index de3f9f3..bc2f24b 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -458,10 +458,14 @@ def to_matrix_repr(
             initial=(state, tuple()),
         ))
 
-        state, ordered_variable_index = get_variable_indices_from_variable(state, variables)
+        state, variable_index = get_variable_indices_from_variable(state, variables)
 
-        assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables'
+        tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index)
+
+        ordered_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
 
+        assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables'
+        
         variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
 
         n_param = len(ordered_variable_index)
@@ -483,7 +487,7 @@ def to_matrix_repr(
                         continue
                     
                     for monomial, value in underlying_terms.items():
-                        
+
                         def gen_new_monomial():
                             for var, count in monomial:
                                 try:
diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py
index 441634b..b4c8d19 100644
--- a/polymatrix/expression/mixins/tosortedvariablesmixin.py
+++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py
@@ -21,8 +21,12 @@ class ToSortedVariablesMixin(ExpressionBaseMixin):
     ) -> tuple[ExpressionState, PolyMatrix]:
         state, variable_indices = get_variable_indices_from_variable(state, self.underlying)
 
+        tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_indices)
+
+        ordered_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
+
         def gen_sorted_vector():
-            for row, index in enumerate(sorted(variable_indices)):
+            for row, index in enumerate(ordered_variable_index):
                 yield (row, 0), {((index, 1),): 1}
         
         poly_matrix = init_poly_matrix(
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index 901c567..61bae2a 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -1,9 +1,9 @@
 import itertools
-import typing
+
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 
 
-def get_variable_indices_from_variable(state, variable) -> typing.Optional[tuple[int]]:
+def get_variable_indices_from_variable(state, variable) -> tuple[int] | None:
     
     if isinstance(variable, ExpressionBaseMixin):
         state, variable_polynomial = variable.apply(state)
@@ -29,9 +29,11 @@ def get_variable_indices_from_variable(state, variable) -> typing.Optional[tuple
         variable_indices = tuple(gen_variables_indices())
 
     elif isinstance(variable, int):
+        # raise Exception(f'{variable=}')
         variable_indices = (variable,)
 
     elif variable in state.offset_dict:
+        # raise Exception(f'{variable=}')
         variable_indices = (state.offset_dict[variable][0],)
 
     else:
diff --git a/polymatrix/expressionstate/mixins/expressionstatemixin.py b/polymatrix/expressionstate/mixins/expressionstatemixin.py
index 18dba05..6268d4b 100644
--- a/polymatrix/expressionstate/mixins/expressionstatemixin.py
+++ b/polymatrix/expressionstate/mixins/expressionstatemixin.py
@@ -36,6 +36,12 @@ class ExpressionStateMixin(
     def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]:        
         ...
 
+    def get_name_from_offset(self, offset: int):
+        for variable, (start, end) in self.offset_dict.items():
+            if start <= offset < end:
+                return f'{str(variable)}_{offset-start}'
+
+
     def get_key_from_offset(self, offset: int):
         for variable, (start, end) in self.offset_dict.items():
             if start <= offset < end:
-- 
cgit v1.2.1