summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-06-06 00:02:32 +0200
committerNao Pross <np@0hm.ch>2024-06-06 00:02:32 +0200
commit1251467d2f25007fd51546f417225a3649acaff0 (patch)
treeba9da0cc04f6427603fb9fe693b65d910847aa9b
parentAllow shape argument of variables to be (int, ExpressionMixin) etc. (diff)
downloadpolymatrix-1251467d2f25007fd51546f417225a3649acaff0.tar.gz
polymatrix-1251467d2f25007fd51546f417225a3649acaff0.zip
Allow construction from mixed tuples in from_any
In other words from_any((1, x, y)) with x and y of any type in FromSupportedTypes (not necessarily the same) works.
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/from_.py46
1 files changed, 34 insertions, 12 deletions
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index 7b8c93c..ce35a83 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -62,18 +62,40 @@ def from_any_or(value: FromSupportedTypes, value_if_not_supported: Any) -> Expre
if len(value) < 1:
return value_if_not_supported
- if isinstance(value[0], tuple):
- if isinstance(value[0], int | float):
- return from_numbers(value)
-
- elif isinstance(value[0], sympy.Expr):
- return from_sympy(value)
-
- elif isinstance(value[0], int | float):
- return from_numbers(value)
-
- elif isinstance(value[0], sympy.Expr):
- return from_sympy((value,))
+ from polymatrix.expression import v_stack, h_stack
+
+ # matrix given as tuple[tuple[...]], row major order
+ if all(isinstance(row, tuple) for row in value):
+ wrapped_rows: list[list[Expression]] = []
+ for row in value:
+ if len(row) != len(value[0]):
+ return value_if_not_supported
+
+ wrapped_row: list[Expression] = []
+ for col in row:
+ wrapped = from_any_or(col, None)
+ if wrapped is None:
+ return value_if_not_supported
+ wrapped_row.append(wrapped)
+
+ wrapped_rows.append(wrapped_row)
+ return v_stack(h_stack(row) for row in wrapped_rows)
+
+
+ # row vector tuple[...]
+ elif all(not isinstance(row, tuple) for v in value):
+ wrapped_rows: list[Expression] = []
+ for row in value:
+ wrapped = from_any_or(row, None)
+ if wrapped is None:
+ return value_if_not_supported
+ wrapped_rows.append(wrapped)
+
+ return h_stack(wrapped_rows)
+
+ # invalid
+ else:
+ value_if_not_supported
return value_if_not_supported