summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/init/initfromsympyexpr.py11
-rw-r--r--polymatrix/expression/init/initfromtermsexpr.py16
-rw-r--r--polymatrix/expression/init/initsubstituteexpr.py3
-rw-r--r--polymatrix/expression/mixins/fromsympyexprmixin.py62
4 files changed, 64 insertions, 28 deletions
diff --git a/polymatrix/expression/init/initfromsympyexpr.py b/polymatrix/expression/init/initfromsympyexpr.py
index 08a6a04..bb37f1d 100644
--- a/polymatrix/expression/init/initfromsympyexpr.py
+++ b/polymatrix/expression/init/initfromsympyexpr.py
@@ -11,7 +11,16 @@ def init_from_sympy_expr(
match data:
case np.ndarray():
- data = tuple(tuple(i for i in row) for row in data)
+ assert len(data.shape) <= 2
+
+ def gen_elements():
+ for row in data:
+ if isinstance(row, np.ndarray):
+ yield tuple(row)
+ else:
+ yield (row,)
+
+ data = tuple(gen_elements())
case sympy.Matrix():
data = tuple(tuple(i for i in data.row(row)) for row in range(data.rows))
diff --git a/polymatrix/expression/init/initfromtermsexpr.py b/polymatrix/expression/init/initfromtermsexpr.py
index 8083eca..c2080c6 100644
--- a/polymatrix/expression/init/initfromtermsexpr.py
+++ b/polymatrix/expression/init/initfromtermsexpr.py
@@ -4,17 +4,27 @@ from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
def init_from_terms_expr(
- terms: typing.Union[tuple, PolyMatrixMixin],
- shape: tuple[int, int] = None,
+ terms: typing.Union[tuple, PolyMatrixMixin],
+ shape: tuple[int, int] = None,
):
+
if isinstance(terms, PolyMatrixMixin):
shape = terms.shape
gen_terms = terms.gen_terms()
else:
assert shape is not None
- gen_terms = terms
+ if isinstance(terms, tuple):
+ gen_terms = terms
+
+ elif isinstance(terms, dict):
+ gen_terms = terms.items()
+
+ else:
+ raise Exception(f'{terms=}')
+
+ # Expression needs to be hashable
terms_formatted = tuple((key, tuple(monomials.items())) for key, monomials in gen_terms)
return FromTermsExprImpl(
diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py
index db2f2b8..403b169 100644
--- a/polymatrix/expression/init/initsubstituteexpr.py
+++ b/polymatrix/expression/init/initsubstituteexpr.py
@@ -13,6 +13,9 @@ def init_substitute_expr(
if substitutions is None:
assert isinstance(variables, tuple)
+ if len(variables) == 0:
+ return underlying
+
variables, substitutions = tuple(zip(*variables))
elif isinstance(substitutions, np.ndarray):
diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py
index 8853022..8414a25 100644
--- a/polymatrix/expression/mixins/fromsympyexprmixin.py
+++ b/polymatrix/expression/mixins/fromsympyexprmixin.py
@@ -2,6 +2,7 @@
import abc
import math
import sympy
+import numpy as np
from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -26,39 +27,52 @@ class FromSympyExprMixin(ExpressionBaseMixin):
for poly_row, col_data in enumerate(self.data):
for poly_col, poly_data in enumerate(col_data):
- try:
- poly = sympy.poly(poly_data)
+ if isinstance(poly_data, (int, float)):
+ if math.isclose(poly_data, 0):
+ terms_row_col = {}
+ else:
+ terms_row_col = {tuple(): poly_data}
- except sympy.polys.polyerrors.GeneratorsNeeded:
- if not math.isclose(poly_data, 0):
- terms[poly_row, poly_col] = {tuple(): poly_data}
- continue
+ elif isinstance(poly_data, sympy.Expr):
+ try:
+ poly = sympy.poly(poly_data)
- except ValueError:
- raise ValueError(f'{poly_data=}')
+ except sympy.polys.polyerrors.GeneratorsNeeded:
+ if not math.isclose(poly_data, 0):
+ terms[poly_row, poly_col] = {tuple(): poly_data}
+ continue
- for symbol in poly.gens:
- state = state.register(key=symbol, n_param=1)
- # print(f'{symbol}: {state.n_param}')
+ except ValueError:
+ raise ValueError(f'{poly_data=}')
- terms_row_col = {}
+ for symbol in poly.gens:
+ state = state.register(key=symbol, n_param=1)
+ # print(f'{symbol}: {state.n_param}')
- # a5 x1 x3**2 -> c=a5, m_cnt=(1, 0, 2)
- for value, monomial_count in zip(poly.coeffs(), poly.monoms()):
+ terms_row_col = {}
- if math.isclose(value, 0):
- continue
+ # a5 x1 x3**2 -> c=a5, m_cnt=(1, 0, 2)
+ for value, monomial_count in zip(poly.coeffs(), poly.monoms()):
+
+ if math.isclose(value, 0):
+ continue
+
+ # m_cnt=(1, 0, 2) -> m=(0, 2, 2)
+ def gen_monomial():
+ for rel_idx, p in enumerate(monomial_count):
+ if 0 < p:
+ idx, _ = state.offset_dict[poly.gens[rel_idx]]
+ yield idx, p
+
+ monomial = tuple(gen_monomial())
- # m_cnt=(1, 0, 2) -> m=(0, 2, 2)
- def gen_monomial():
- for rel_idx, p in enumerate(monomial_count):
- if 0 < p:
- idx, _ = state.offset_dict[poly.gens[rel_idx]]
- yield idx, p
+ terms_row_col[monomial] = value
- monomial = tuple(gen_monomial())
+ elif isinstance(poly_data, np.ndarray) and np.issubdtype(poly_data, np.number):
+ terms_row_col = {tuple(): float(poly_data)}
- terms_row_col[monomial] = value
+ else:
+ raise Exception(f'{poly_data=}, {type(poly_data)=}')
terms[poly_row, poly_col] = terms_row_col