diff options
author | Nao Pross <np@0hm.ch> | 2024-05-27 14:59:53 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-27 14:59:53 +0200 |
commit | d1886e25874399ad13d3a7b28ccb9e7a47629305 (patch) | |
tree | 1f789802c8604ef9e796357187b0e0a11bd72077 | |
parent | Add missing check for degrees in CombinationsExpr (diff) | |
download | polymatrix-d1886e25874399ad13d3a7b28ccb9e7a47629305.tar.gz polymatrix-d1886e25874399ad13d3a7b28ccb9e7a47629305.zip |
Improve polymatrix.ones and zeros
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/__init__.py | 33 |
1 files changed, 29 insertions, 4 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index e3fa464..eac0a07 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -34,10 +34,22 @@ def convert_args_to_expression(fn: Callable) -> Callable: return wrapper -@convert_args_to_expression -def ones(shape: Expression): +def ones(shape: int | tuple[int, int] | Expression) -> Expression: """ Make a matrix filled with ones """ - return init_expression(init_ns_expr(n=1, shape=shape.underlying)) + if isinstance(shape, Expression): + expr = init_ns_expr(n=1, shape=shape.underlying) + + elif isinstance(shape, int): + expr = init_ns_expr(n=1, shape=(shape, 1)) + + elif isinstance(shape, tuple): + expr = init_ns_expr(n=1, shape=shape) + + else: + # TODO: error messages + raise TypeError + + return init_expression(expr) @convert_args_to_expression @@ -60,7 +72,20 @@ def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expr @convert_args_to_expression def zeros(shape: Expression): """ Make a matrix filled with zeros """ - return init_expression(init_ns_expr(n=0, shape=shape.underlying)) + if isinstance(shape, Expression): + expr = init_ns_expr(n=0, shape=shape.underlying) + + elif isinstance(shape, int): + expr = init_ns_expr(n=0, shape=(shape, 1)) + + elif isinstance(shape, tuple): + expr = init_ns_expr(n=0, shape=shape) + + else: + # TODO: error messages + raise TypeError + + return init_expression(expr) def v_stack(expressions: Iterable) -> Expression: |