summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-06-05 00:50:51 +0200
committerNao Pross <np@0hm.ch>2024-06-05 00:50:51 +0200
commit6446cade4fee2bd2697c784330795511bf3be132 (patch)
tree892c9629971441538dca5daa25291616f4aff2d7
parentSeparate symbols from variables to avoid shape problems (diff)
downloadpolymatrix-6446cade4fee2bd2697c784330795511bf3be132.tar.gz
polymatrix-6446cade4fee2bd2697c784330795511bf3be132.zip
Fix bug in arange()
-rw-r--r--polymatrix/expression/__init__.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index da1dce6..94f84f9 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -57,18 +57,20 @@ def ones(shape: int | tuple[int, int] | Expression) -> Expression:
def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expression | None = None):
# Replicate range()'s behaviour
if stop is None and step is None:
- return init_arange_expr(start=None, stop=start_or_stop, step=None)
+ e = init_arange_expr(start=None, stop=start_or_stop.underlying, step=None)
elif stop is not None and step is None:
- return init_arange_expr(start=start_or_stop, stop=stop, step=None)
+ e = init_arange_expr(start=start_or_stop.underlying, stop=stop.underlying, step=None)
elif stop is not None and step is not None:
- return init_arange_expr(start=start_or_stop, stop=stop, step=step)
+ e = init_arange_expr(start=start_or_stop.underlying, stop=stop.underlying, step=step.underlying)
else:
# FIXME: error message missing
raise ValueError
+ return init_expression(e)
+
@convert_args_to_expression
def zeros(shape: Expression):