diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/__init__.py | 61 | ||||
-rw-r--r-- | polymatrix/expression/from_.py | 11 | ||||
-rw-r--r-- | polymatrix/expression/typing.py | 1 |
3 files changed, 21 insertions, 52 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index ed3ba2e..c7c365d 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -7,6 +7,7 @@ import polymatrix.expression.impl from polymatrix.utils.getstacklines import get_stack_lines from polymatrix.expression.expression import init_expression, Expression +from polymatrix.expression.from_ import from_any from polymatrix.expression.mixins.namedexprmixin import NamedExprMixin from polymatrix.expression.init import ( @@ -41,30 +42,14 @@ def ones(shape: int | tuple[int | Expression, int | Expression] | Expression) -> expr = init_ns_expr(n=1, shape=shape.underlying) elif isinstance(shape, int): - expr = init_ns_expr(n=1, shape=(shape, 1)) + expr = init_ns_expr(n=1, shape=v_stack((shape, 1))) elif isinstance(shape, tuple): - if isinstance(shape[0], Expression): - nrows = shape[0].underlying + nrows, ncols = shape + if (not isinstance(nrows, int)) or (not isinstance(ncols, int)): + shape = from_any(((nrows,), (ncols,))).underlying - elif isinstance(shape[0], int): - nrows = shape[0] - - else: - raise TypeError("Number of rows must be an integer or an expression " - "that evaluates to an integer") - - if isinstance(shape[1], Expression): - ncols = shape[1].underlying - - elif isinstance(shape[1], int): - ncols = shape[1] - - else: - raise TypeError("Number of columns must be an integer or an expression " - "that evaluates to an integer") - - expr = init_ns_expr(n=1, shape=(nrows, ncols)) + expr = init_ns_expr(n=1, shape=shape) else: # TODO: error messages @@ -98,30 +83,14 @@ def zeros(shape: int | tuple[int | Expression, int | Expression] | Expression): expr = init_ns_expr(n=0, shape=shape.underlying) elif isinstance(shape, int): - expr = init_ns_expr(n=0, shape=(shape, 1)) + expr = init_ns_expr(n=0, shape=v_stack((shape, 1))) elif isinstance(shape, tuple): - if isinstance(shape[0], Expression): - nrows = shape[0].underlying - - elif isinstance(shape[0], int): - nrows = shape[0] + nrows, ncols = shape + if (not isinstance(nrows, int)) or (not isinstance(ncols, int)): + shape = from_any(((nrows,), (ncols,))).underlying - else: - raise TypeError("Number of rows must be an integer or an expression " - "that evaluates to an integer") - - if isinstance(shape[1], Expression): - ncols = shape[1].underlying - - elif isinstance(shape[1], int): - ncols = shape[1] - - else: - raise TypeError("Number of columns must be an integer or an expression " - "that evaluates to an integer") - - expr = init_ns_expr(n=0, shape=(nrows, ncols)) + expr = init_ns_expr(n=0, shape=shape) else: # TODO: error messages @@ -134,14 +103,18 @@ def v_stack(expressions: Iterable) -> Expression: """ Vertically stack expressions """ # arrange blocks vertically and concatanete blocks = tuple((from_.from_any(e).underlying,) for e in expressions) - return init_expression(init_concatenate_expr(blocks)) + u = init_expression(init_concatenate_expr(blocks)) + names = ", ".join(str(b[0]) for b in blocks) + return give_name(u, f"vstack({names})") def h_stack(expressions: Iterable) -> Expression: """ Horizontally stack expressions """ # arrange horizontally blocks = (tuple(from_.from_any(e).underlying for e in expressions),) - return init_expression(init_concatenate_expr(blocks)) + u = init_expression(init_concatenate_expr(blocks)) + names = ", ".join(str(b) for b in blocks[0]) + return give_name(u, f"hstack({names})") def concatenate(expressions: Iterable[Iterable]): diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index 03091c7..3be44e5 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -83,7 +83,7 @@ def from_any_or(value: FromSupportedTypes, value_if_not_supported: Any) -> Expre # row vector tuple[...] - elif all(not isinstance(row, tuple) for v in value): + elif all(not isinstance(row, tuple) for row in value): wrapped_rows: list[Expression] = [] for row in value: wrapped = from_any_or(row, None) @@ -119,13 +119,8 @@ def from_name( elif isinstance(shape, tuple): nrows, ncols = shape - if isinstance(nrows, Expression): - nrows = nrows.underlying - - if isinstance(ncols, Expression): - ncols = ncols.underlying - - shape = (nrows, ncols) + if (not isinstance(nrows, int)) or (not isinstance(ncols, int)): + shape = from_any(((nrows,), (ncols,))).underlying return init_variable_expression(init.init_variable_expr(Symbol(name), shape)) diff --git a/polymatrix/expression/typing.py b/polymatrix/expression/typing.py index 87ade26..23d7f66 100644 --- a/polymatrix/expression/typing.py +++ b/polymatrix/expression/typing.py @@ -11,6 +11,7 @@ FromSupportedTypes = ( str | NDArray | sympy.Matrix | sympy.Expr + | tuple[...] | tuple[tuple[...]] | ExpressionBaseMixin | StateMonad |