diff options
-rw-r--r-- | polymatrix/expression/__init__.py | 33 |
1 files changed, 29 insertions, 4 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index e3fa464..eac0a07 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -34,10 +34,22 @@ def convert_args_to_expression(fn: Callable) -> Callable: return wrapper -@convert_args_to_expression -def ones(shape: Expression): +def ones(shape: int | tuple[int, int] | Expression) -> Expression: """ Make a matrix filled with ones """ - return init_expression(init_ns_expr(n=1, shape=shape.underlying)) + if isinstance(shape, Expression): + expr = init_ns_expr(n=1, shape=shape.underlying) + + elif isinstance(shape, int): + expr = init_ns_expr(n=1, shape=(shape, 1)) + + elif isinstance(shape, tuple): + expr = init_ns_expr(n=1, shape=shape) + + else: + # TODO: error messages + raise TypeError + + return init_expression(expr) @convert_args_to_expression @@ -60,7 +72,20 @@ def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expr @convert_args_to_expression def zeros(shape: Expression): """ Make a matrix filled with zeros """ - return init_expression(init_ns_expr(n=0, shape=shape.underlying)) + if isinstance(shape, Expression): + expr = init_ns_expr(n=0, shape=shape.underlying) + + elif isinstance(shape, int): + expr = init_ns_expr(n=0, shape=(shape, 1)) + + elif isinstance(shape, tuple): + expr = init_ns_expr(n=0, shape=shape) + + else: + # TODO: error messages + raise TypeError + + return init_expression(expr) def v_stack(expressions: Iterable) -> Expression: |