summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/__init__.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index da1dce6..94f84f9 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -57,18 +57,20 @@ def ones(shape: int | tuple[int, int] | Expression) -> Expression:
def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expression | None = None):
# Replicate range()'s behaviour
if stop is None and step is None:
- return init_arange_expr(start=None, stop=start_or_stop, step=None)
+ e = init_arange_expr(start=None, stop=start_or_stop.underlying, step=None)
elif stop is not None and step is None:
- return init_arange_expr(start=start_or_stop, stop=stop, step=None)
+ e = init_arange_expr(start=start_or_stop.underlying, stop=stop.underlying, step=None)
elif stop is not None and step is not None:
- return init_arange_expr(start=start_or_stop, stop=stop, step=step)
+ e = init_arange_expr(start=start_or_stop.underlying, stop=stop.underlying, step=step.underlying)
else:
# FIXME: error message missing
raise ValueError
+ return init_expression(e)
+
@convert_args_to_expression
def zeros(shape: Expression):