From d1886e25874399ad13d3a7b28ccb9e7a47629305 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 27 May 2024 14:59:53 +0200 Subject: Improve polymatrix.ones and zeros --- polymatrix/expression/__init__.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index e3fa464..eac0a07 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -34,10 +34,22 @@ def convert_args_to_expression(fn: Callable) -> Callable: return wrapper -@convert_args_to_expression -def ones(shape: Expression): +def ones(shape: int | tuple[int, int] | Expression) -> Expression: """ Make a matrix filled with ones """ - return init_expression(init_ns_expr(n=1, shape=shape.underlying)) + if isinstance(shape, Expression): + expr = init_ns_expr(n=1, shape=shape.underlying) + + elif isinstance(shape, int): + expr = init_ns_expr(n=1, shape=(shape, 1)) + + elif isinstance(shape, tuple): + expr = init_ns_expr(n=1, shape=shape) + + else: + # TODO: error messages + raise TypeError + + return init_expression(expr) @convert_args_to_expression @@ -60,7 +72,20 @@ def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expr @convert_args_to_expression def zeros(shape: Expression): """ Make a matrix filled with zeros """ - return init_expression(init_ns_expr(n=0, shape=shape.underlying)) + if isinstance(shape, Expression): + expr = init_ns_expr(n=0, shape=shape.underlying) + + elif isinstance(shape, int): + expr = init_ns_expr(n=0, shape=(shape, 1)) + + elif isinstance(shape, tuple): + expr = init_ns_expr(n=0, shape=shape) + + else: + # TODO: error messages + raise TypeError + + return init_expression(expr) def v_stack(expressions: Iterable) -> Expression: -- cgit v1.2.1