diff options
author | Nao Pross <np@0hm.ch> | 2024-06-05 21:04:20 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-06-05 21:04:20 +0200 |
commit | cb720e1b51b1609557c3a9a49e5881a3c8ecff30 (patch) | |
tree | fe526f0825cd9a71632eb0058bacee01e6897f22 | |
parent | Fix bug from renaming (diff) | |
download | polymatrix-cb720e1b51b1609557c3a9a49e5881a3c8ecff30.tar.gz polymatrix-cb720e1b51b1609557c3a9a49e5881a3c8ecff30.zip |
Allow shape argument of variables to be (int, ExpressionMixin) etc.
-rw-r--r-- | polymatrix/expression/from_.py | 10 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/mixins/variableexprmixin.py | 33 |
4 files changed, 36 insertions, 11 deletions
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index 7198e0c..7b8c93c 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -78,13 +78,19 @@ def from_any_or(value: FromSupportedTypes, value_if_not_supported: Any) -> Expre return value_if_not_supported -def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> Iterable[VariableExpression]: +def from_names( + names: str, + shape: Expression | tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin] | ExpressionBaseMixin = (1,1) +) -> Iterable[VariableExpression]: """ Construct one or multiple variables from comma separated a list of names. """ for name in names.split(","): yield from_name(name.strip(), shape) -def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> VariableExpression: +def from_name( + name: str, + shape: Expression | tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin] | ExpressionBaseMixin = (1,1) +) -> VariableExpression: """ Construct a variable from its names """ if isinstance(shape, Expression): shape = shape.underlying diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index eccf99a..764cab8 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -510,7 +510,7 @@ class TruncateExprImpl(TruncateExprMixin): @dataclassabc.dataclassabc(frozen=True) class VariableExprImpl(VariableExprMixin): symbol: Symbol - shape: tuple[int, int] | ExpressionBaseMixin + shape: tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin] | ExpressionBaseMixin def __str__(self): return self.symbol diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 4b8d182..6ed517e 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -496,5 +496,5 @@ def init_truncate_expr( ) -def init_variable_expr(sym: Symbol, shape: tuple[int, int] | ExpressionBaseMixin): +def init_variable_expr(sym: Symbol, shape: tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin] | ExpressionBaseMixin): return polymatrix.expression.impl.VariableExprImpl(sym, shape) diff --git a/polymatrix/expression/mixins/variableexprmixin.py b/polymatrix/expression/mixins/variableexprmixin.py index 83bc66e..6e78be0 100644 --- a/polymatrix/expression/mixins/variableexprmixin.py +++ b/polymatrix/expression/mixins/variableexprmixin.py @@ -17,7 +17,7 @@ class VariableExprMixin(ExpressionBaseMixin): @property @abstractmethod - def shape(self) -> tuple[int, int] | ExpressionBaseMixin: + def shape(self) -> tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin] | ExpressionBaseMixin: """ Shape of the variable expression. """ @property @@ -37,17 +37,36 @@ class VariableExprMixin(ExpressionBaseMixin): nrows = int(shape_pm.at(0, 0).constant()) ncols = int(shape_pm.at(1, 0).constant()) - # Replace shape field with computed shape - state = state.register(self.symbol, shape=(nrows, ncols)) - elif isinstance(self.shape, tuple): - nrows, ncols = self.shape # for for loop below - state = state.register(self.symbol, self.shape) + if isinstance(self.shape[0], ExpressionBaseMixin): + state, nrows_pm = self.shape[0].apply(state) + nrows = int(nrows_pm.scalar().constant()) + + elif isinstance(self.shape[0], int): + nrows = self.shape[0] + + else: + raise TypeError("Number of row in shape must be either an integer, " + "or an expression that evaluates to a scalar integer. " + f"But given value has type {type(self.shape[0])}") + + if isinstance(self.shape[1], ExpressionBaseMixin): + state, ncols_pm = self.shape[1].apply(state) + ncols = int(ncols_pm.scalar().constant()) + + elif isinstance(self.shape[1], int): + ncols = self.shape[1] + + else: + raise TypeError("Number of columns in shape must be either an integer, " + "or a expression that evaluates to a scalar integer. " + f"But given value has type {type(self.shape[0])}") else: - raise ValueError("Shape must be a tuple or expression that " + raise TypeError("Shape must be a tuple or expression that " f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}") + state = state.register(self.symbol, shape=(nrows, ncols)) indices = state.get_indices(self.symbol) p = PolyMatrixDict() |