From fa2948d3f1ce1d712e323b33d527ba27c9665a99 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 27 May 2024 13:06:38 +0200 Subject: Expose polymatrix.ones and polymatrix.zeros to user --- polymatrix/__init__.py | 5 +++++ polymatrix/expression/__init__.py | 5 +++-- polymatrix/expression/impl.py | 3 +++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index a235661..d0f53ee 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -18,6 +18,8 @@ from polymatrix.expression.from_ import ( from polymatrix.expression import ( Expression as internal_Expression, arange as internal_arange, + ones as internal_ones, + zeros as internal_zeros, v_stack as internal_v_stack, h_stack as internal_h_stack, product as internal_product, @@ -42,6 +44,9 @@ init_expression_state = internal_init_expression_state make_state = init_expression_state arange = internal_arange +ones = internal_ones +zeros = internal_zeros + v_stack = internal_v_stack h_stack = internal_h_stack product = internal_product diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index 56bc83a..e3fa464 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -37,7 +37,8 @@ def convert_args_to_expression(fn: Callable) -> Callable: @convert_args_to_expression def ones(shape: Expression): """ Make a matrix filled with ones """ - return init_expression(init_ns_expr(n=1., shape=shape.underlying)) + return init_expression(init_ns_expr(n=1, shape=shape.underlying)) + @convert_args_to_expression def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expression | None = None): @@ -59,7 +60,7 @@ def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expr @convert_args_to_expression def zeros(shape: Expression): """ Make a matrix filled with zeros """ - return init_expression(init_ns_expr(n=1., shape=shape.underlying)) + return init_expression(init_ns_expr(n=0, shape=shape.underlying)) def v_stack(expressions: Iterable) -> Expression: diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index ac15aef..8eb1857 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -344,6 +344,9 @@ class NsExprImpl(NsExprMixin): n: int | float | ExpressionBaseMixin shape: tuple[int, int] | ExpressionBaseMixin + def __str__(self): + return f"({self.n}s)" + @dataclassabc.dataclassabc(frozen=True, repr=False) class ParametrizeExprImpl(ParametrizeExprMixin): -- cgit v1.2.1