diff options
-rw-r--r-- | polymatrix/expression/mixins/fromsympyexprmixin.py | 30 |
1 files changed, 17 insertions, 13 deletions
diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py index 9b26f09..4728bbb 100644 --- a/polymatrix/expression/mixins/fromsympyexprmixin.py +++ b/polymatrix/expression/mixins/fromsympyexprmixin.py @@ -14,7 +14,7 @@ class FromSympyExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod def data(self) -> tuple[tuple[float]]: - pass + ... # overwrites abstract method of `ExpressionBaseMixin` def apply( @@ -22,16 +22,16 @@ class FromSympyExprMixin(ExpressionBaseMixin): state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: - terms = {} + polynomials = {} for poly_row, col_data in enumerate(self.data): for poly_col, poly_data in enumerate(col_data): if isinstance(poly_data, (int, float)): if math.isclose(poly_data, 0): - terms_row_col = {} + polynomial = {} else: - terms_row_col = {tuple(): poly_data} + polynomial = {tuple(): poly_data} elif isinstance(poly_data, sympy.Expr): try: @@ -39,7 +39,7 @@ class FromSympyExprMixin(ExpressionBaseMixin): except sympy.polys.polyerrors.GeneratorsNeeded: if not math.isclose(poly_data, 0): - terms[poly_row, poly_col] = {tuple(): poly_data} + polynomials[poly_row, poly_col] = {tuple(): poly_data} continue except ValueError: @@ -49,7 +49,7 @@ class FromSympyExprMixin(ExpressionBaseMixin): state = state.register(key=symbol, n_param=1) # print(f'{symbol}: {state.n_param}') - terms_row_col = {} + polynomial = {} # a5 x1 x3**2 -> c=a5, m_cnt=(1, 0, 2) for value, monomial_count in zip(poly.coeffs(), poly.monoms()): @@ -66,21 +66,25 @@ class FromSympyExprMixin(ExpressionBaseMixin): monomial = tuple(gen_monomial()) - terms_row_col[monomial] = value + polynomial[monomial] = value - elif isinstance(poly_data, np.number): - terms_row_col = {tuple(): float(poly_data)} + # elif isinstance(poly_data, np.number): + # terms_row_col = {tuple(): float(poly_data)} - # elif isinstance(poly_data, ExpressionBaseMixin): - # pass + elif isinstance(poly_data, ExpressionBaseMixin): + state, instance = poly_data.apply(state) + + assert instance.shape == (1, 1) + + polynomial = instance.get_poly(0, 0) else: raise Exception(f'{poly_data=}, {type(poly_data)=}') - terms[poly_row, poly_col] = terms_row_col + polynomials[poly_row, poly_col] = polynomial poly_matrix = init_poly_matrix( - terms=terms, + terms=polynomials, shape=(len(self.data), len(self.data[0])), ) |