diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/from_.py | 46 |
1 files changed, 34 insertions, 12 deletions
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index 7b8c93c..ce35a83 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -62,18 +62,40 @@ def from_any_or(value: FromSupportedTypes, value_if_not_supported: Any) -> Expre if len(value) < 1: return value_if_not_supported - if isinstance(value[0], tuple): - if isinstance(value[0], int | float): - return from_numbers(value) - - elif isinstance(value[0], sympy.Expr): - return from_sympy(value) - - elif isinstance(value[0], int | float): - return from_numbers(value) - - elif isinstance(value[0], sympy.Expr): - return from_sympy((value,)) + from polymatrix.expression import v_stack, h_stack + + # matrix given as tuple[tuple[...]], row major order + if all(isinstance(row, tuple) for row in value): + wrapped_rows: list[list[Expression]] = [] + for row in value: + if len(row) != len(value[0]): + return value_if_not_supported + + wrapped_row: list[Expression] = [] + for col in row: + wrapped = from_any_or(col, None) + if wrapped is None: + return value_if_not_supported + wrapped_row.append(wrapped) + + wrapped_rows.append(wrapped_row) + return v_stack(h_stack(row) for row in wrapped_rows) + + + # row vector tuple[...] + elif all(not isinstance(row, tuple) for v in value): + wrapped_rows: list[Expression] = [] + for row in value: + wrapped = from_any_or(row, None) + if wrapped is None: + return value_if_not_supported + wrapped_rows.append(wrapped) + + return h_stack(wrapped_rows) + + # invalid + else: + value_if_not_supported return value_if_not_supported |