summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-27 14:59:53 +0200
committerNao Pross <np@0hm.ch>2024-05-27 14:59:53 +0200
commitd1886e25874399ad13d3a7b28ccb9e7a47629305 (patch)
tree1f789802c8604ef9e796357187b0e0a11bd72077
parentAdd missing check for degrees in CombinationsExpr (diff)
downloadpolymatrix-d1886e25874399ad13d3a7b28ccb9e7a47629305.tar.gz
polymatrix-d1886e25874399ad13d3a7b28ccb9e7a47629305.zip
Improve polymatrix.ones and zeros
-rw-r--r--polymatrix/expression/__init__.py33
1 files changed, 29 insertions, 4 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index e3fa464..eac0a07 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -34,10 +34,22 @@ def convert_args_to_expression(fn: Callable) -> Callable:
return wrapper
-@convert_args_to_expression
-def ones(shape: Expression):
+def ones(shape: int | tuple[int, int] | Expression) -> Expression:
""" Make a matrix filled with ones """
- return init_expression(init_ns_expr(n=1, shape=shape.underlying))
+ if isinstance(shape, Expression):
+ expr = init_ns_expr(n=1, shape=shape.underlying)
+
+ elif isinstance(shape, int):
+ expr = init_ns_expr(n=1, shape=(shape, 1))
+
+ elif isinstance(shape, tuple):
+ expr = init_ns_expr(n=1, shape=shape)
+
+ else:
+ # TODO: error messages
+ raise TypeError
+
+ return init_expression(expr)
@convert_args_to_expression
@@ -60,7 +72,20 @@ def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expr
@convert_args_to_expression
def zeros(shape: Expression):
""" Make a matrix filled with zeros """
- return init_expression(init_ns_expr(n=0, shape=shape.underlying))
+ if isinstance(shape, Expression):
+ expr = init_ns_expr(n=0, shape=shape.underlying)
+
+ elif isinstance(shape, int):
+ expr = init_ns_expr(n=0, shape=(shape, 1))
+
+ elif isinstance(shape, tuple):
+ expr = init_ns_expr(n=0, shape=shape)
+
+ else:
+ # TODO: error messages
+ raise TypeError
+
+ return init_expression(expr)
def v_stack(expressions: Iterable) -> Expression: