diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/init/initfromsympyexpr.py | 11 | ||||
-rw-r--r-- | polymatrix/expression/init/initfromtermsexpr.py | 16 | ||||
-rw-r--r-- | polymatrix/expression/init/initsubstituteexpr.py | 3 | ||||
-rw-r--r-- | polymatrix/expression/mixins/fromsympyexprmixin.py | 62 |
4 files changed, 64 insertions, 28 deletions
diff --git a/polymatrix/expression/init/initfromsympyexpr.py b/polymatrix/expression/init/initfromsympyexpr.py index 08a6a04..bb37f1d 100644 --- a/polymatrix/expression/init/initfromsympyexpr.py +++ b/polymatrix/expression/init/initfromsympyexpr.py @@ -11,7 +11,16 @@ def init_from_sympy_expr( match data: case np.ndarray(): - data = tuple(tuple(i for i in row) for row in data) + assert len(data.shape) <= 2 + + def gen_elements(): + for row in data: + if isinstance(row, np.ndarray): + yield tuple(row) + else: + yield (row,) + + data = tuple(gen_elements()) case sympy.Matrix(): data = tuple(tuple(i for i in data.row(row)) for row in range(data.rows)) diff --git a/polymatrix/expression/init/initfromtermsexpr.py b/polymatrix/expression/init/initfromtermsexpr.py index 8083eca..c2080c6 100644 --- a/polymatrix/expression/init/initfromtermsexpr.py +++ b/polymatrix/expression/init/initfromtermsexpr.py @@ -4,17 +4,27 @@ from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin def init_from_terms_expr( - terms: typing.Union[tuple, PolyMatrixMixin], - shape: tuple[int, int] = None, + terms: typing.Union[tuple, PolyMatrixMixin], + shape: tuple[int, int] = None, ): + if isinstance(terms, PolyMatrixMixin): shape = terms.shape gen_terms = terms.gen_terms() else: assert shape is not None - gen_terms = terms + if isinstance(terms, tuple): + gen_terms = terms + + elif isinstance(terms, dict): + gen_terms = terms.items() + + else: + raise Exception(f'{terms=}') + + # Expression needs to be hashable terms_formatted = tuple((key, tuple(monomials.items())) for key, monomials in gen_terms) return FromTermsExprImpl( diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py index db2f2b8..403b169 100644 --- a/polymatrix/expression/init/initsubstituteexpr.py +++ b/polymatrix/expression/init/initsubstituteexpr.py @@ -13,6 +13,9 @@ def init_substitute_expr( if substitutions is None: assert isinstance(variables, tuple) + if len(variables) == 0: + return underlying + variables, substitutions = tuple(zip(*variables)) elif isinstance(substitutions, np.ndarray): diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py index 8853022..8414a25 100644 --- a/polymatrix/expression/mixins/fromsympyexprmixin.py +++ b/polymatrix/expression/mixins/fromsympyexprmixin.py @@ -2,6 +2,7 @@ import abc import math import sympy +import numpy as np from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -26,39 +27,52 @@ class FromSympyExprMixin(ExpressionBaseMixin): for poly_row, col_data in enumerate(self.data): for poly_col, poly_data in enumerate(col_data): - try: - poly = sympy.poly(poly_data) + if isinstance(poly_data, (int, float)): + if math.isclose(poly_data, 0): + terms_row_col = {} + else: + terms_row_col = {tuple(): poly_data} - except sympy.polys.polyerrors.GeneratorsNeeded: - if not math.isclose(poly_data, 0): - terms[poly_row, poly_col] = {tuple(): poly_data} - continue + elif isinstance(poly_data, sympy.Expr): + try: + poly = sympy.poly(poly_data) - except ValueError: - raise ValueError(f'{poly_data=}') + except sympy.polys.polyerrors.GeneratorsNeeded: + if not math.isclose(poly_data, 0): + terms[poly_row, poly_col] = {tuple(): poly_data} + continue - for symbol in poly.gens: - state = state.register(key=symbol, n_param=1) - # print(f'{symbol}: {state.n_param}') + except ValueError: + raise ValueError(f'{poly_data=}') - terms_row_col = {} + for symbol in poly.gens: + state = state.register(key=symbol, n_param=1) + # print(f'{symbol}: {state.n_param}') - # a5 x1 x3**2 -> c=a5, m_cnt=(1, 0, 2) - for value, monomial_count in zip(poly.coeffs(), poly.monoms()): + terms_row_col = {} - if math.isclose(value, 0): - continue + # a5 x1 x3**2 -> c=a5, m_cnt=(1, 0, 2) + for value, monomial_count in zip(poly.coeffs(), poly.monoms()): + + if math.isclose(value, 0): + continue + + # m_cnt=(1, 0, 2) -> m=(0, 2, 2) + def gen_monomial(): + for rel_idx, p in enumerate(monomial_count): + if 0 < p: + idx, _ = state.offset_dict[poly.gens[rel_idx]] + yield idx, p + + monomial = tuple(gen_monomial()) - # m_cnt=(1, 0, 2) -> m=(0, 2, 2) - def gen_monomial(): - for rel_idx, p in enumerate(monomial_count): - if 0 < p: - idx, _ = state.offset_dict[poly.gens[rel_idx]] - yield idx, p + terms_row_col[monomial] = value - monomial = tuple(gen_monomial()) + elif isinstance(poly_data, np.ndarray) and np.issubdtype(poly_data, np.number): + terms_row_col = {tuple(): float(poly_data)} - terms_row_col[monomial] = value + else: + raise Exception(f'{poly_data=}, {type(poly_data)=}') terms[poly_row, poly_col] = terms_row_col |