diff options
-rw-r--r-- | polymatrix/expression/to.py | 83 |
1 files changed, 30 insertions, 53 deletions
diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py index 0cea1d9..2a43c0e 100644 --- a/polymatrix/expression/to.py +++ b/polymatrix/expression/to.py @@ -1,3 +1,4 @@ +import math import sympy import numpy as np @@ -46,56 +47,32 @@ def to_constant( def to_sympy( expr: Expression, -) -> StateMonadMixin[ExpressionState, sympy.Expr]: - - # TODO: fix this function with new expression state - raise NotImplementedError("Broken because FromTupleExpr was removed") - - def func(state: ExpressionState): - state, underlying = expr.apply(state) - - A = np.zeros(underlying.shape, dtype=object) - - for (row, col), polynomial in underlying.gen_data(): - sympy_polynomial = 0 - - for monomial, value in polynomial.items(): - sympy_monomial = 1 - - for offset, count in monomial: - variable = state.get_key_from_offset(offset) - # def get_variable_from_offset(offset: int): - # for variable, (start, end) in state.offset_dict.items(): - # if start <= offset < end: - # assert end - start == 1, f'{start=}, {end=}, {variable=}' - - if isinstance(variable, sympy.core.symbol.Symbol): - variable_name = variable.name - elif isinstance( - variable, (ParametrizeExprMixin, ParametrizeMatrixExprMixin) - ): - variable_name = variable.name - elif isinstance(variable, str): - variable_name = variable - else: - raise Exception(f"{variable=}") - - start, end = state.offset_dict[variable] - - if end - start == 1: - sympy_var = sympy.Symbol(variable_name) - else: - sympy_var = sympy.Symbol( - f"{variable_name}_{offset - start + 1}" - ) - - # var = get_variable_from_offset(offset) - sympy_monomial *= sympy_var**count - - sympy_polynomial += value * sympy_monomial - - A[row, col] = sympy_polynomial - - return state, A - - return init_state_monad(func) +) -> StateMonadMixin[ExpressionState, sympy.Expr | sympy.Matrix]: + + def polymatrix_to_sympy(state: ExpressionState) -> sympy.Expr | sympy.Matrix: + # Convert to polymatrix + state, pm = expr.apply(state) + + m = sympy.zeros(*pm.shape) + for entry, poly in pm.entries(): + sympy_poly_terms = [] + for monomial, coeff in poly.terms(): + sympy_monomial = math.prod( + sympy.Symbol(state.get_name(variable.index)) ** variable.power + for variable in monomial) + + if math.isclose(coeff, 1.): + # no need to add 1 in front + sympy_poly_terms.append(sympy_monomial) + + else: + sympy_poly_terms.append(coeff * sympy_monomial) + + m[*entry] = sum(sympy_poly_terms) + + if math.prod(pm.shape) == 1: + # just return the expression + return m[0, 0] + + return m + return init_state_monad(polymatrix_to_sympy) |