summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2023-02-09 10:33:54 +0100
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2023-02-09 10:33:54 +0100
commitf9e0dd5ce540ccdd4019afda315b566ff8fe4499 (patch)
tree6ca2ae7b36f5abb819baf2999dc0a0bac640b336
parentaccept polymatrix.Expression in from_sympy_expr() (diff)
downloadpolymatrix-f9e0dd5ce540ccdd4019afda315b566ff8fe4499.tar.gz
polymatrix-f9e0dd5ce540ccdd4019afda315b566ff8fe4499.zip
sympy.expand a sympy expression before converting it to a polymatrix
-rw-r--r--polymatrix/__init__.py4
-rw-r--r--polymatrix/expression/init/initfromsympyexpr.py5
-rw-r--r--polymatrix/expression/mixins/fromsympyexprmixin.py3
-rw-r--r--polymatrix/expression/mixins/substituteexprmixin.py2
-rw-r--r--polymatrix/expression/utils/getvariableindices.py2
5 files changed, 12 insertions, 4 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index e6cf4a8..dff503b 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -8,6 +8,8 @@ import scipy.sparse
import sympy
from polymatrix.expression.expression import Expression
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.expressionmixin import ExpressionMixin
from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr
@@ -437,7 +439,7 @@ def to_matrix_repr(
if isinstance(expressions, Expression):
expressions = (expressions,)
- assert isinstance(variables, Expression), f'{variables=}'
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
def func(state: ExpressionState):
diff --git a/polymatrix/expression/init/initfromsympyexpr.py b/polymatrix/expression/init/initfromsympyexpr.py
index f8dbd94..2ab6ded 100644
--- a/polymatrix/expression/init/initfromsympyexpr.py
+++ b/polymatrix/expression/init/initfromsympyexpr.py
@@ -43,8 +43,11 @@ def init_from_sympy_expr(
case ExpressionBaseMixin():
return data
+ case sympy.Expr():
+ data = ((sympy.expand(data),),)
+
case _:
- if not isinstance(data, (float, int, sympy.Expr)):
+ if not isinstance(data, (float, int)):
raise Exception(f'{data=}, {type(data)=}')
data = ((data,),)
diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py
index e2efcd4..9b26f09 100644
--- a/polymatrix/expression/mixins/fromsympyexprmixin.py
+++ b/polymatrix/expression/mixins/fromsympyexprmixin.py
@@ -71,6 +71,9 @@ class FromSympyExprMixin(ExpressionBaseMixin):
elif isinstance(poly_data, np.number):
terms_row_col = {tuple(): float(poly_data)}
+ # elif isinstance(poly_data, ExpressionBaseMixin):
+ # pass
+
else:
raise Exception(f'{poly_data=}, {type(poly_data)=}')
diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py
index 6c0beae..f69c870 100644
--- a/polymatrix/expression/mixins/substituteexprmixin.py
+++ b/polymatrix/expression/mixins/substituteexprmixin.py
@@ -42,7 +42,7 @@ class SubstituteExprMixin(ExpressionBaseMixin):
state, substitution = expr.apply(state)
- assert substitution.shape[1] == 1, f'{substitution=}'
+ assert substitution.shape[1] == 1, f'The following expression has to be a vector {expr=}'
def gen_polynomials():
for row in range(substitution.shape[0]):
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index 371894e..901c567 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -15,7 +15,7 @@ def get_variable_indices_from_variable(state, variable) -> typing.Optional[tuple
for row in range(variable_polynomial.shape[0]):
row_terms = variable_polynomial.get_poly(row, 0)
- assert len(row_terms) == 1, f'{row_terms} contains more than one term'
+ assert len(row_terms) == 1, f'{row_terms} does not contain a single term'
for monomial in row_terms.keys():
assert len(monomial) <= 1, f'{monomial=} contains more than one variable'