summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/__init__.py33
1 files changed, 29 insertions, 4 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index e3fa464..eac0a07 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -34,10 +34,22 @@ def convert_args_to_expression(fn: Callable) -> Callable:
return wrapper
-@convert_args_to_expression
-def ones(shape: Expression):
+def ones(shape: int | tuple[int, int] | Expression) -> Expression:
""" Make a matrix filled with ones """
- return init_expression(init_ns_expr(n=1, shape=shape.underlying))
+ if isinstance(shape, Expression):
+ expr = init_ns_expr(n=1, shape=shape.underlying)
+
+ elif isinstance(shape, int):
+ expr = init_ns_expr(n=1, shape=(shape, 1))
+
+ elif isinstance(shape, tuple):
+ expr = init_ns_expr(n=1, shape=shape)
+
+ else:
+ # TODO: error messages
+ raise TypeError
+
+ return init_expression(expr)
@convert_args_to_expression
@@ -60,7 +72,20 @@ def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expr
@convert_args_to_expression
def zeros(shape: Expression):
""" Make a matrix filled with zeros """
- return init_expression(init_ns_expr(n=0, shape=shape.underlying))
+ if isinstance(shape, Expression):
+ expr = init_ns_expr(n=0, shape=shape.underlying)
+
+ elif isinstance(shape, int):
+ expr = init_ns_expr(n=0, shape=(shape, 1))
+
+ elif isinstance(shape, tuple):
+ expr = init_ns_expr(n=0, shape=shape)
+
+ else:
+ # TODO: error messages
+ raise TypeError
+
+ return init_expression(expr)
def v_stack(expressions: Iterable) -> Expression: