From 1c761458fc117329bca0c5ec4eb9ced05c66945a Mon Sep 17 00:00:00 2001
From: Michael Schneeberger <michael.schneeberger@fhnw.ch>
Date: Fri, 26 Aug 2022 15:06:51 +0200
Subject: 'from_sympy_expr' method works with different data types

---
 polymatrix/expression/init/initfromsympyexpr.py    | 11 +++-
 polymatrix/expression/init/initfromtermsexpr.py    | 16 ++++--
 polymatrix/expression/init/initsubstituteexpr.py   |  3 ++
 polymatrix/expression/mixins/fromsympyexprmixin.py | 62 +++++++++++++---------
 4 files changed, 64 insertions(+), 28 deletions(-)

diff --git a/polymatrix/expression/init/initfromsympyexpr.py b/polymatrix/expression/init/initfromsympyexpr.py
index 08a6a04..bb37f1d 100644
--- a/polymatrix/expression/init/initfromsympyexpr.py
+++ b/polymatrix/expression/init/initfromsympyexpr.py
@@ -11,7 +11,16 @@ def init_from_sympy_expr(
 
 	match data:
 		case np.ndarray():
-			data = tuple(tuple(i for i in row) for row in data)
+			assert len(data.shape) <= 2
+
+			def gen_elements():
+				for row in data:
+					if isinstance(row, np.ndarray):
+						yield tuple(row)
+					else:
+						yield (row,)
+
+			data = tuple(gen_elements())
 
 		case sympy.Matrix():
 			data = tuple(tuple(i for i in data.row(row)) for row in range(data.rows))
diff --git a/polymatrix/expression/init/initfromtermsexpr.py b/polymatrix/expression/init/initfromtermsexpr.py
index 8083eca..c2080c6 100644
--- a/polymatrix/expression/init/initfromtermsexpr.py
+++ b/polymatrix/expression/init/initfromtermsexpr.py
@@ -4,17 +4,27 @@ from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
 
 
 def init_from_terms_expr(
-		terms: typing.Union[tuple, PolyMatrixMixin],
-		shape: tuple[int, int] = None,
+	terms: typing.Union[tuple, PolyMatrixMixin],
+	shape: tuple[int, int] = None,
 ):
+
 	if isinstance(terms, PolyMatrixMixin):
 		shape = terms.shape
 		gen_terms = terms.gen_terms()
 
 	else:
 		assert shape is not None
-		gen_terms = terms
 
+		if isinstance(terms, tuple):
+			gen_terms = terms
+
+		elif isinstance(terms, dict):
+			gen_terms = terms.items()
+
+		else:
+			raise Exception(f'{terms=}')
+
+	# Expression needs to be hashable
 	terms_formatted = tuple((key, tuple(monomials.items())) for key, monomials in gen_terms)
 
 	return FromTermsExprImpl(
diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py
index db2f2b8..403b169 100644
--- a/polymatrix/expression/init/initsubstituteexpr.py
+++ b/polymatrix/expression/init/initsubstituteexpr.py
@@ -13,6 +13,9 @@ def init_substitute_expr(
 	if substitutions is None:
 		assert isinstance(variables, tuple)
 
+		if len(variables) == 0:
+			return underlying
+
 		variables, substitutions = tuple(zip(*variables))
 
 	elif isinstance(substitutions, np.ndarray):
diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py
index 8853022..8414a25 100644
--- a/polymatrix/expression/mixins/fromsympyexprmixin.py
+++ b/polymatrix/expression/mixins/fromsympyexprmixin.py
@@ -2,6 +2,7 @@
 import abc
 import math
 import sympy
+import numpy as np
 
 from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -26,39 +27,52 @@ class FromSympyExprMixin(ExpressionBaseMixin):
         for poly_row, col_data in enumerate(self.data):
             for poly_col, poly_data in enumerate(col_data):
 
-                try:
-                    poly = sympy.poly(poly_data)
+                if isinstance(poly_data, (int, float)):
+                    if math.isclose(poly_data, 0):
+                        terms_row_col = {}
+                    else:
+                        terms_row_col = {tuple(): poly_data}
 
-                except sympy.polys.polyerrors.GeneratorsNeeded:
-                    if not math.isclose(poly_data, 0):
-                        terms[poly_row, poly_col] = {tuple(): poly_data}
-                    continue
+                elif isinstance(poly_data, sympy.Expr):
+                    try:
+                        poly = sympy.poly(poly_data)
 
-                except ValueError:
-                    raise ValueError(f'{poly_data=}')
+                    except sympy.polys.polyerrors.GeneratorsNeeded:
+                        if not math.isclose(poly_data, 0):
+                            terms[poly_row, poly_col] = {tuple(): poly_data}
+                        continue
 
-                for symbol in poly.gens:
-                    state = state.register(key=symbol, n_param=1)
-                    # print(f'{symbol}: {state.n_param}')
+                    except ValueError:
+                        raise ValueError(f'{poly_data=}')
 
-                terms_row_col = {}
+                    for symbol in poly.gens:
+                        state = state.register(key=symbol, n_param=1)
+                        # print(f'{symbol}: {state.n_param}')
 
-                # a5 x1 x3**2 -> c=a5, m_cnt=(1, 0, 2)
-                for value, monomial_count in zip(poly.coeffs(), poly.monoms()):
+                    terms_row_col = {}
 
-                    if math.isclose(value, 0):
-                        continue
+                    # a5 x1 x3**2 -> c=a5, m_cnt=(1, 0, 2)
+                    for value, monomial_count in zip(poly.coeffs(), poly.monoms()):
+
+                        if math.isclose(value, 0):
+                            continue
+
+                        # m_cnt=(1, 0, 2) -> m=(0, 2, 2)
+                        def gen_monomial():
+                            for rel_idx, p in enumerate(monomial_count):
+                                if 0 < p:
+                                    idx, _ = state.offset_dict[poly.gens[rel_idx]]
+                                    yield idx, p
+
+                        monomial = tuple(gen_monomial())
 
-                    # m_cnt=(1, 0, 2) -> m=(0, 2, 2)
-                    def gen_monomial():
-                        for rel_idx, p in enumerate(monomial_count):
-                            if 0 < p:
-                                idx, _ = state.offset_dict[poly.gens[rel_idx]]
-                                yield idx, p
+                        terms_row_col[monomial] = value
 
-                    monomial = tuple(gen_monomial())
+                elif isinstance(poly_data, np.ndarray) and np.issubdtype(poly_data, np.number):
+                    terms_row_col = {tuple(): float(poly_data)}
 
-                    terms_row_col[monomial] = value
+                else:
+                    raise Exception(f'{poly_data=}, {type(poly_data)=}')
 
                 terms[poly_row, poly_col] = terms_row_col
 
-- 
cgit v1.2.1