summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/mixins/fromsympyexprmixin.py24
1 files changed, 16 insertions, 8 deletions
diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py
index a0003fb..301dce3 100644
--- a/polymatrix/expression/mixins/fromsympyexprmixin.py
+++ b/polymatrix/expression/mixins/fromsympyexprmixin.py
@@ -36,15 +36,19 @@ class FromSympyExprMixin(ExpressionBaseMixin):
@override
def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrix]:
+
# Unpack if it is a sympy matrix
- data = sympy.expand(self.data)
- if isinstance(data, sympy.Matrix):
- data = ((entry for entry in data.row(r))
- for r in range(data.rows))
+ if isinstance(self.data, sympy.Matrix):
+ data = cast(tuple[tuple[sympy.Expr]],
+ ((entry for entry in self.data.row(r))
+ for r in range(self.data.rows)))
# Pack if is raw (scalar) expression
- if isinstance(data, sympy.Expr):
- data = ((data,),)
+ elif isinstance(self.data, sympy.Expr):
+ data = ((self.data,),)
+
+ else:
+ data = self.data
# Convert to polymatrix
polymatrix = PolyMatrixDict.empty()
@@ -65,18 +69,22 @@ class FromSympyExprMixin(ExpressionBaseMixin):
raise ValueError(f"Cannot convert sympy expression {entry} "
"into a polynomial, are you sure it is a polynomial?") from e
- # Convert sympy variables to our variables
+ # Convert sympy variables to our variables, i.e VariableMixin
+ # FIXME: This import cannot be above because of circular imports
+ # not sure how I am supposed to fit this in correctly into the
+ # dataclass + mixin pattern structure
from polymatrix.expression.init import init_variable_expr
sympy_to_var = {
sympy_idx: init_variable_expr(var.name)
for sympy_idx, var in enumerate(sympy_poly.gens)
}
+ # Construct poynomial
poly = PolyDict.empty()
for coeff, monom in zip(sympy_poly.coeffs(), sympy_poly.monoms()):
# sympy monomial is stored as multi-index, eg. for a
# multivariate polynomial with three variables (x, y, z)
- # the index is x*y**2 = (1, 2, 0)
+ # the index is x * y**2 = (1, 2, 0)
m: list[VariableIndex] = []
for i, exponent in enumerate(monom):
var = sympy_to_var[i]