diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-02-09 10:33:54 +0100 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-02-09 10:33:54 +0100 |
commit | f9e0dd5ce540ccdd4019afda315b566ff8fe4499 (patch) | |
tree | 6ca2ae7b36f5abb819baf2999dc0a0bac640b336 | |
parent | accept polymatrix.Expression in from_sympy_expr() (diff) | |
download | polymatrix-f9e0dd5ce540ccdd4019afda315b566ff8fe4499.tar.gz polymatrix-f9e0dd5ce540ccdd4019afda315b566ff8fe4499.zip |
sympy.expand a sympy expression before converting it to a polymatrix
-rw-r--r-- | polymatrix/__init__.py | 4 | ||||
-rw-r--r-- | polymatrix/expression/init/initfromsympyexpr.py | 5 | ||||
-rw-r--r-- | polymatrix/expression/mixins/fromsympyexprmixin.py | 3 | ||||
-rw-r--r-- | polymatrix/expression/mixins/substituteexprmixin.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/utils/getvariableindices.py | 2 |
5 files changed, 12 insertions, 4 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index e6cf4a8..dff503b 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -8,6 +8,8 @@ import scipy.sparse import sympy from polymatrix.expression.expression import Expression +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.expressionmixin import ExpressionMixin from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin from polymatrix.expressionstate.expressionstate import ExpressionState from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr @@ -437,7 +439,7 @@ def to_matrix_repr( if isinstance(expressions, Expression): expressions = (expressions,) - assert isinstance(variables, Expression), f'{variables=}' + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' def func(state: ExpressionState): diff --git a/polymatrix/expression/init/initfromsympyexpr.py b/polymatrix/expression/init/initfromsympyexpr.py index f8dbd94..2ab6ded 100644 --- a/polymatrix/expression/init/initfromsympyexpr.py +++ b/polymatrix/expression/init/initfromsympyexpr.py @@ -43,8 +43,11 @@ def init_from_sympy_expr( case ExpressionBaseMixin(): return data + case sympy.Expr(): + data = ((sympy.expand(data),),) + case _: - if not isinstance(data, (float, int, sympy.Expr)): + if not isinstance(data, (float, int)): raise Exception(f'{data=}, {type(data)=}') data = ((data,),) diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py index e2efcd4..9b26f09 100644 --- a/polymatrix/expression/mixins/fromsympyexprmixin.py +++ b/polymatrix/expression/mixins/fromsympyexprmixin.py @@ -71,6 +71,9 @@ class FromSympyExprMixin(ExpressionBaseMixin): elif isinstance(poly_data, np.number): terms_row_col = {tuple(): float(poly_data)} + # elif isinstance(poly_data, ExpressionBaseMixin): + # pass + else: raise Exception(f'{poly_data=}, {type(poly_data)=}') diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py index 6c0beae..f69c870 100644 --- a/polymatrix/expression/mixins/substituteexprmixin.py +++ b/polymatrix/expression/mixins/substituteexprmixin.py @@ -42,7 +42,7 @@ class SubstituteExprMixin(ExpressionBaseMixin): state, substitution = expr.apply(state) - assert substitution.shape[1] == 1, f'{substitution=}' + assert substitution.shape[1] == 1, f'The following expression has to be a vector {expr=}' def gen_polynomials(): for row in range(substitution.shape[0]): diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index 371894e..901c567 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -15,7 +15,7 @@ def get_variable_indices_from_variable(state, variable) -> typing.Optional[tuple for row in range(variable_polynomial.shape[0]): row_terms = variable_polynomial.get_poly(row, 0) - assert len(row_terms) == 1, f'{row_terms} contains more than one term' + assert len(row_terms) == 1, f'{row_terms} does not contain a single term' for monomial in row_terms.keys(): assert len(monomial) <= 1, f'{monomial=} contains more than one variable' |