diff options
author | Nao Pross <np@0hm.ch> | 2024-06-06 00:02:32 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-06-06 00:02:32 +0200 |
commit | 1251467d2f25007fd51546f417225a3649acaff0 (patch) | |
tree | ba9da0cc04f6427603fb9fe693b65d910847aa9b | |
parent | Allow shape argument of variables to be (int, ExpressionMixin) etc. (diff) | |
download | polymatrix-1251467d2f25007fd51546f417225a3649acaff0.tar.gz polymatrix-1251467d2f25007fd51546f417225a3649acaff0.zip |
Allow construction from mixed tuples in from_any
In other words from_any((1, x, y)) with x and y of any type in
FromSupportedTypes (not necessarily the same) works.
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/from_.py | 46 |
1 files changed, 34 insertions, 12 deletions
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index 7b8c93c..ce35a83 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -62,18 +62,40 @@ def from_any_or(value: FromSupportedTypes, value_if_not_supported: Any) -> Expre if len(value) < 1: return value_if_not_supported - if isinstance(value[0], tuple): - if isinstance(value[0], int | float): - return from_numbers(value) - - elif isinstance(value[0], sympy.Expr): - return from_sympy(value) - - elif isinstance(value[0], int | float): - return from_numbers(value) - - elif isinstance(value[0], sympy.Expr): - return from_sympy((value,)) + from polymatrix.expression import v_stack, h_stack + + # matrix given as tuple[tuple[...]], row major order + if all(isinstance(row, tuple) for row in value): + wrapped_rows: list[list[Expression]] = [] + for row in value: + if len(row) != len(value[0]): + return value_if_not_supported + + wrapped_row: list[Expression] = [] + for col in row: + wrapped = from_any_or(col, None) + if wrapped is None: + return value_if_not_supported + wrapped_row.append(wrapped) + + wrapped_rows.append(wrapped_row) + return v_stack(h_stack(row) for row in wrapped_rows) + + + # row vector tuple[...] + elif all(not isinstance(row, tuple) for v in value): + wrapped_rows: list[Expression] = [] + for row in value: + wrapped = from_any_or(row, None) + if wrapped is None: + return value_if_not_supported + wrapped_rows.append(wrapped) + + return h_stack(wrapped_rows) + + # invalid + else: + value_if_not_supported return value_if_not_supported |