diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/mixins/fromsympyexprmixin.py | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py index a0003fb..301dce3 100644 --- a/polymatrix/expression/mixins/fromsympyexprmixin.py +++ b/polymatrix/expression/mixins/fromsympyexprmixin.py @@ -36,15 +36,19 @@ class FromSympyExprMixin(ExpressionBaseMixin): @override def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrix]: + # Unpack if it is a sympy matrix - data = sympy.expand(self.data) - if isinstance(data, sympy.Matrix): - data = ((entry for entry in data.row(r)) - for r in range(data.rows)) + if isinstance(self.data, sympy.Matrix): + data = cast(tuple[tuple[sympy.Expr]], + ((entry for entry in self.data.row(r)) + for r in range(self.data.rows))) # Pack if is raw (scalar) expression - if isinstance(data, sympy.Expr): - data = ((data,),) + elif isinstance(self.data, sympy.Expr): + data = ((self.data,),) + + else: + data = self.data # Convert to polymatrix polymatrix = PolyMatrixDict.empty() @@ -65,18 +69,22 @@ class FromSympyExprMixin(ExpressionBaseMixin): raise ValueError(f"Cannot convert sympy expression {entry} " "into a polynomial, are you sure it is a polynomial?") from e - # Convert sympy variables to our variables + # Convert sympy variables to our variables, i.e VariableMixin + # FIXME: This import cannot be above because of circular imports + # not sure how I am supposed to fit this in correctly into the + # dataclass + mixin pattern structure from polymatrix.expression.init import init_variable_expr sympy_to_var = { sympy_idx: init_variable_expr(var.name) for sympy_idx, var in enumerate(sympy_poly.gens) } + # Construct poynomial poly = PolyDict.empty() for coeff, monom in zip(sympy_poly.coeffs(), sympy_poly.monoms()): # sympy monomial is stored as multi-index, eg. for a # multivariate polynomial with three variables (x, y, z) - # the index is x*y**2 = (1, 2, 0) + # the index is x * y**2 = (1, 2, 0) m: list[VariableIndex] = [] for i, exponent in enumerate(monom): var = sympy_to_var[i] |