summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/to.py83
1 files changed, 30 insertions, 53 deletions
diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py
index 0cea1d9..2a43c0e 100644
--- a/polymatrix/expression/to.py
+++ b/polymatrix/expression/to.py
@@ -1,3 +1,4 @@
+import math
import sympy
import numpy as np
@@ -46,56 +47,32 @@ def to_constant(
def to_sympy(
expr: Expression,
-) -> StateMonadMixin[ExpressionState, sympy.Expr]:
-
- # TODO: fix this function with new expression state
- raise NotImplementedError("Broken because FromTupleExpr was removed")
-
- def func(state: ExpressionState):
- state, underlying = expr.apply(state)
-
- A = np.zeros(underlying.shape, dtype=object)
-
- for (row, col), polynomial in underlying.gen_data():
- sympy_polynomial = 0
-
- for monomial, value in polynomial.items():
- sympy_monomial = 1
-
- for offset, count in monomial:
- variable = state.get_key_from_offset(offset)
- # def get_variable_from_offset(offset: int):
- # for variable, (start, end) in state.offset_dict.items():
- # if start <= offset < end:
- # assert end - start == 1, f'{start=}, {end=}, {variable=}'
-
- if isinstance(variable, sympy.core.symbol.Symbol):
- variable_name = variable.name
- elif isinstance(
- variable, (ParametrizeExprMixin, ParametrizeMatrixExprMixin)
- ):
- variable_name = variable.name
- elif isinstance(variable, str):
- variable_name = variable
- else:
- raise Exception(f"{variable=}")
-
- start, end = state.offset_dict[variable]
-
- if end - start == 1:
- sympy_var = sympy.Symbol(variable_name)
- else:
- sympy_var = sympy.Symbol(
- f"{variable_name}_{offset - start + 1}"
- )
-
- # var = get_variable_from_offset(offset)
- sympy_monomial *= sympy_var**count
-
- sympy_polynomial += value * sympy_monomial
-
- A[row, col] = sympy_polynomial
-
- return state, A
-
- return init_state_monad(func)
+) -> StateMonadMixin[ExpressionState, sympy.Expr | sympy.Matrix]:
+
+ def polymatrix_to_sympy(state: ExpressionState) -> sympy.Expr | sympy.Matrix:
+ # Convert to polymatrix
+ state, pm = expr.apply(state)
+
+ m = sympy.zeros(*pm.shape)
+ for entry, poly in pm.entries():
+ sympy_poly_terms = []
+ for monomial, coeff in poly.terms():
+ sympy_monomial = math.prod(
+ sympy.Symbol(state.get_name(variable.index)) ** variable.power
+ for variable in monomial)
+
+ if math.isclose(coeff, 1.):
+ # no need to add 1 in front
+ sympy_poly_terms.append(sympy_monomial)
+
+ else:
+ sympy_poly_terms.append(coeff * sympy_monomial)
+
+ m[*entry] = sum(sympy_poly_terms)
+
+ if math.prod(pm.shape) == 1:
+ # just return the expression
+ return m[0, 0]
+
+ return m
+ return init_state_monad(polymatrix_to_sympy)