summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/__init__.py49
-rw-r--r--polymatrix/expression/mixins/nsexprmixin.py24
2 files changed, 66 insertions, 7 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index 94f84f9..ed3ba2e 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -35,7 +35,7 @@ def convert_args_to_expression(fn: Callable) -> Callable:
return wrapper
-def ones(shape: int | tuple[int, int] | Expression) -> Expression:
+def ones(shape: int | tuple[int | Expression, int | Expression] | Expression) -> Expression:
""" Make a matrix filled with ones """
if isinstance(shape, Expression):
expr = init_ns_expr(n=1, shape=shape.underlying)
@@ -44,7 +44,27 @@ def ones(shape: int | tuple[int, int] | Expression) -> Expression:
expr = init_ns_expr(n=1, shape=(shape, 1))
elif isinstance(shape, tuple):
- expr = init_ns_expr(n=1, shape=shape)
+ if isinstance(shape[0], Expression):
+ nrows = shape[0].underlying
+
+ elif isinstance(shape[0], int):
+ nrows = shape[0]
+
+ else:
+ raise TypeError("Number of rows must be an integer or an expression "
+ "that evaluates to an integer")
+
+ if isinstance(shape[1], Expression):
+ ncols = shape[1].underlying
+
+ elif isinstance(shape[1], int):
+ ncols = shape[1]
+
+ else:
+ raise TypeError("Number of columns must be an integer or an expression "
+ "that evaluates to an integer")
+
+ expr = init_ns_expr(n=1, shape=(nrows, ncols))
else:
# TODO: error messages
@@ -72,8 +92,7 @@ def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expr
return init_expression(e)
-@convert_args_to_expression
-def zeros(shape: Expression):
+def zeros(shape: int | tuple[int | Expression, int | Expression] | Expression):
""" Make a matrix filled with zeros """
if isinstance(shape, Expression):
expr = init_ns_expr(n=0, shape=shape.underlying)
@@ -82,7 +101,27 @@ def zeros(shape: Expression):
expr = init_ns_expr(n=0, shape=(shape, 1))
elif isinstance(shape, tuple):
- expr = init_ns_expr(n=0, shape=shape)
+ if isinstance(shape[0], Expression):
+ nrows = shape[0].underlying
+
+ elif isinstance(shape[0], int):
+ nrows = shape[0]
+
+ else:
+ raise TypeError("Number of rows must be an integer or an expression "
+ "that evaluates to an integer")
+
+ if isinstance(shape[1], Expression):
+ ncols = shape[1].underlying
+
+ elif isinstance(shape[1], int):
+ ncols = shape[1]
+
+ else:
+ raise TypeError("Number of columns must be an integer or an expression "
+ "that evaluates to an integer")
+
+ expr = init_ns_expr(n=0, shape=(nrows, ncols))
else:
# TODO: error messages
diff --git a/polymatrix/expression/mixins/nsexprmixin.py b/polymatrix/expression/mixins/nsexprmixin.py
index 6129093..ea95330 100644
--- a/polymatrix/expression/mixins/nsexprmixin.py
+++ b/polymatrix/expression/mixins/nsexprmixin.py
@@ -48,8 +48,28 @@ class NsExprMixin(ExpressionBaseMixin):
nrows = int(pm.at(0, 0).constant())
ncols = int(pm.at(1, 0).constant())
- else:
- nrows, ncols = self.shape
+ elif isinstance(self.shape, tuple):
+ if isinstance(self.shape[0], ExpressionBaseMixin):
+ state, nrows_pm = self.shape[0].apply(state)
+ nrows = int(nrows_pm.scalar().constant())
+
+ elif isinstance(self.shape[0], int):
+ nrows = self.shape[0]
+
+ else:
+ # TODO: error message
+ raise TypeError
+
+ if isinstance(self.shape[1], ExpressionBaseMixin):
+ state, ncols_pm = self.shape[1].apply(state)
+ ncols = int(ncols_pm.scalar().constant())
+
+ elif isinstance(self.shape[1], int):
+ ncols = self.shape[1]
+
+ else:
+ # TODO: error message
+ raise TypeError
return state, init_broadcast_poly_matrix(n, shape=(nrows, ncols))