diff options
author | Nao Pross <np@0hm.ch> | 2024-06-07 11:20:46 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-06-07 11:20:46 +0200 |
commit | 06458caa1d4723a19c69f0da80765dfcb36eca07 (patch) | |
tree | 55f337f4d9cf196a8873eb58878e52b1d86fab3c | |
parent | Fix odd behaviour of MaxExpr (diff) | |
download | polymatrix-06458caa1d4723a19c69f0da80765dfcb36eca07.tar.gz polymatrix-06458caa1d4723a19c69f0da80765dfcb36eca07.zip |
Fix bug when working with new shapes (e, 5) where e is an expression
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/__init__.py | 49 | ||||
-rw-r--r-- | polymatrix/expression/mixins/nsexprmixin.py | 24 |
2 files changed, 66 insertions, 7 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index 94f84f9..ed3ba2e 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -35,7 +35,7 @@ def convert_args_to_expression(fn: Callable) -> Callable: return wrapper -def ones(shape: int | tuple[int, int] | Expression) -> Expression: +def ones(shape: int | tuple[int | Expression, int | Expression] | Expression) -> Expression: """ Make a matrix filled with ones """ if isinstance(shape, Expression): expr = init_ns_expr(n=1, shape=shape.underlying) @@ -44,7 +44,27 @@ def ones(shape: int | tuple[int, int] | Expression) -> Expression: expr = init_ns_expr(n=1, shape=(shape, 1)) elif isinstance(shape, tuple): - expr = init_ns_expr(n=1, shape=shape) + if isinstance(shape[0], Expression): + nrows = shape[0].underlying + + elif isinstance(shape[0], int): + nrows = shape[0] + + else: + raise TypeError("Number of rows must be an integer or an expression " + "that evaluates to an integer") + + if isinstance(shape[1], Expression): + ncols = shape[1].underlying + + elif isinstance(shape[1], int): + ncols = shape[1] + + else: + raise TypeError("Number of columns must be an integer or an expression " + "that evaluates to an integer") + + expr = init_ns_expr(n=1, shape=(nrows, ncols)) else: # TODO: error messages @@ -72,8 +92,7 @@ def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expr return init_expression(e) -@convert_args_to_expression -def zeros(shape: Expression): +def zeros(shape: int | tuple[int | Expression, int | Expression] | Expression): """ Make a matrix filled with zeros """ if isinstance(shape, Expression): expr = init_ns_expr(n=0, shape=shape.underlying) @@ -82,7 +101,27 @@ def zeros(shape: Expression): expr = init_ns_expr(n=0, shape=(shape, 1)) elif isinstance(shape, tuple): - expr = init_ns_expr(n=0, shape=shape) + if isinstance(shape[0], Expression): + nrows = shape[0].underlying + + elif isinstance(shape[0], int): + nrows = shape[0] + + else: + raise TypeError("Number of rows must be an integer or an expression " + "that evaluates to an integer") + + if isinstance(shape[1], Expression): + ncols = shape[1].underlying + + elif isinstance(shape[1], int): + ncols = shape[1] + + else: + raise TypeError("Number of columns must be an integer or an expression " + "that evaluates to an integer") + + expr = init_ns_expr(n=0, shape=(nrows, ncols)) else: # TODO: error messages diff --git a/polymatrix/expression/mixins/nsexprmixin.py b/polymatrix/expression/mixins/nsexprmixin.py index 6129093..ea95330 100644 --- a/polymatrix/expression/mixins/nsexprmixin.py +++ b/polymatrix/expression/mixins/nsexprmixin.py @@ -48,8 +48,28 @@ class NsExprMixin(ExpressionBaseMixin): nrows = int(pm.at(0, 0).constant()) ncols = int(pm.at(1, 0).constant()) - else: - nrows, ncols = self.shape + elif isinstance(self.shape, tuple): + if isinstance(self.shape[0], ExpressionBaseMixin): + state, nrows_pm = self.shape[0].apply(state) + nrows = int(nrows_pm.scalar().constant()) + + elif isinstance(self.shape[0], int): + nrows = self.shape[0] + + else: + # TODO: error message + raise TypeError + + if isinstance(self.shape[1], ExpressionBaseMixin): + state, ncols_pm = self.shape[1].apply(state) + ncols = int(ncols_pm.scalar().constant()) + + elif isinstance(self.shape[1], int): + ncols = self.shape[1] + + else: + # TODO: error message + raise TypeError return state, init_broadcast_poly_matrix(n, shape=(nrows, ncols)) |