diff options
-rw-r--r-- | polymatrix/expression/__init__.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index da1dce6..94f84f9 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -57,18 +57,20 @@ def ones(shape: int | tuple[int, int] | Expression) -> Expression: def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expression | None = None): # Replicate range()'s behaviour if stop is None and step is None: - return init_arange_expr(start=None, stop=start_or_stop, step=None) + e = init_arange_expr(start=None, stop=start_or_stop.underlying, step=None) elif stop is not None and step is None: - return init_arange_expr(start=start_or_stop, stop=stop, step=None) + e = init_arange_expr(start=start_or_stop.underlying, stop=stop.underlying, step=None) elif stop is not None and step is not None: - return init_arange_expr(start=start_or_stop, stop=stop, step=step) + e = init_arange_expr(start=start_or_stop.underlying, stop=stop.underlying, step=step.underlying) else: # FIXME: error message missing raise ValueError + return init_expression(e) + @convert_args_to_expression def zeros(shape: Expression): |