summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/__init__.py10
-rw-r--r--polymatrix/expression/mixins/tosortedvariablesmixin.py6
-rw-r--r--polymatrix/expression/utils/getvariableindices.py6
-rw-r--r--polymatrix/expressionstate/mixins/expressionstatemixin.py6
4 files changed, 22 insertions, 6 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index de3f9f3..bc2f24b 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -458,10 +458,14 @@ def to_matrix_repr(
initial=(state, tuple()),
))
- state, ordered_variable_index = get_variable_indices_from_variable(state, variables)
+ state, variable_index = get_variable_indices_from_variable(state, variables)
- assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables'
+ tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index)
+
+ ordered_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
+ assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables'
+
variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
n_param = len(ordered_variable_index)
@@ -483,7 +487,7 @@ def to_matrix_repr(
continue
for monomial, value in underlying_terms.items():
-
+
def gen_new_monomial():
for var, count in monomial:
try:
diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py
index 441634b..b4c8d19 100644
--- a/polymatrix/expression/mixins/tosortedvariablesmixin.py
+++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py
@@ -21,8 +21,12 @@ class ToSortedVariablesMixin(ExpressionBaseMixin):
) -> tuple[ExpressionState, PolyMatrix]:
state, variable_indices = get_variable_indices_from_variable(state, self.underlying)
+ tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_indices)
+
+ ordered_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
+
def gen_sorted_vector():
- for row, index in enumerate(sorted(variable_indices)):
+ for row, index in enumerate(ordered_variable_index):
yield (row, 0), {((index, 1),): 1}
poly_matrix = init_poly_matrix(
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index 901c567..61bae2a 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -1,9 +1,9 @@
import itertools
-import typing
+
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-def get_variable_indices_from_variable(state, variable) -> typing.Optional[tuple[int]]:
+def get_variable_indices_from_variable(state, variable) -> tuple[int] | None:
if isinstance(variable, ExpressionBaseMixin):
state, variable_polynomial = variable.apply(state)
@@ -29,9 +29,11 @@ def get_variable_indices_from_variable(state, variable) -> typing.Optional[tuple
variable_indices = tuple(gen_variables_indices())
elif isinstance(variable, int):
+ # raise Exception(f'{variable=}')
variable_indices = (variable,)
elif variable in state.offset_dict:
+ # raise Exception(f'{variable=}')
variable_indices = (state.offset_dict[variable][0],)
else:
diff --git a/polymatrix/expressionstate/mixins/expressionstatemixin.py b/polymatrix/expressionstate/mixins/expressionstatemixin.py
index 18dba05..6268d4b 100644
--- a/polymatrix/expressionstate/mixins/expressionstatemixin.py
+++ b/polymatrix/expressionstate/mixins/expressionstatemixin.py
@@ -36,6 +36,12 @@ class ExpressionStateMixin(
def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]:
...
+ def get_name_from_offset(self, offset: int):
+ for variable, (start, end) in self.offset_dict.items():
+ if start <= offset < end:
+ return f'{str(variable)}_{offset-start}'
+
+
def get_key_from_offset(self, offset: int):
for variable, (start, end) in self.offset_dict.items():
if start <= offset < end: