summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-06-05 21:04:20 +0200
committerNao Pross <np@0hm.ch>2024-06-05 21:04:20 +0200
commitcb720e1b51b1609557c3a9a49e5881a3c8ecff30 (patch)
treefe526f0825cd9a71632eb0058bacee01e6897f22
parentFix bug from renaming (diff)
downloadpolymatrix-cb720e1b51b1609557c3a9a49e5881a3c8ecff30.tar.gz
polymatrix-cb720e1b51b1609557c3a9a49e5881a3c8ecff30.zip
Allow shape argument of variables to be (int, ExpressionMixin) etc.
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/from_.py10
-rw-r--r--polymatrix/expression/impl.py2
-rw-r--r--polymatrix/expression/init.py2
-rw-r--r--polymatrix/expression/mixins/variableexprmixin.py33
4 files changed, 36 insertions, 11 deletions
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index 7198e0c..7b8c93c 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -78,13 +78,19 @@ def from_any_or(value: FromSupportedTypes, value_if_not_supported: Any) -> Expre
return value_if_not_supported
-def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> Iterable[VariableExpression]:
+def from_names(
+ names: str,
+ shape: Expression | tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin] | ExpressionBaseMixin = (1,1)
+) -> Iterable[VariableExpression]:
""" Construct one or multiple variables from comma separated a list of names. """
for name in names.split(","):
yield from_name(name.strip(), shape)
-def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> VariableExpression:
+def from_name(
+ name: str,
+ shape: Expression | tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin] | ExpressionBaseMixin = (1,1)
+) -> VariableExpression:
""" Construct a variable from its names """
if isinstance(shape, Expression):
shape = shape.underlying
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index eccf99a..764cab8 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -510,7 +510,7 @@ class TruncateExprImpl(TruncateExprMixin):
@dataclassabc.dataclassabc(frozen=True)
class VariableExprImpl(VariableExprMixin):
symbol: Symbol
- shape: tuple[int, int] | ExpressionBaseMixin
+ shape: tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin] | ExpressionBaseMixin
def __str__(self):
return self.symbol
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 4b8d182..6ed517e 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -496,5 +496,5 @@ def init_truncate_expr(
)
-def init_variable_expr(sym: Symbol, shape: tuple[int, int] | ExpressionBaseMixin):
+def init_variable_expr(sym: Symbol, shape: tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin] | ExpressionBaseMixin):
return polymatrix.expression.impl.VariableExprImpl(sym, shape)
diff --git a/polymatrix/expression/mixins/variableexprmixin.py b/polymatrix/expression/mixins/variableexprmixin.py
index 83bc66e..6e78be0 100644
--- a/polymatrix/expression/mixins/variableexprmixin.py
+++ b/polymatrix/expression/mixins/variableexprmixin.py
@@ -17,7 +17,7 @@ class VariableExprMixin(ExpressionBaseMixin):
@property
@abstractmethod
- def shape(self) -> tuple[int, int] | ExpressionBaseMixin:
+ def shape(self) -> tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin] | ExpressionBaseMixin:
""" Shape of the variable expression. """
@property
@@ -37,17 +37,36 @@ class VariableExprMixin(ExpressionBaseMixin):
nrows = int(shape_pm.at(0, 0).constant())
ncols = int(shape_pm.at(1, 0).constant())
- # Replace shape field with computed shape
- state = state.register(self.symbol, shape=(nrows, ncols))
-
elif isinstance(self.shape, tuple):
- nrows, ncols = self.shape # for for loop below
- state = state.register(self.symbol, self.shape)
+ if isinstance(self.shape[0], ExpressionBaseMixin):
+ state, nrows_pm = self.shape[0].apply(state)
+ nrows = int(nrows_pm.scalar().constant())
+
+ elif isinstance(self.shape[0], int):
+ nrows = self.shape[0]
+
+ else:
+ raise TypeError("Number of row in shape must be either an integer, "
+ "or an expression that evaluates to a scalar integer. "
+ f"But given value has type {type(self.shape[0])}")
+
+ if isinstance(self.shape[1], ExpressionBaseMixin):
+ state, ncols_pm = self.shape[1].apply(state)
+ ncols = int(ncols_pm.scalar().constant())
+
+ elif isinstance(self.shape[1], int):
+ ncols = self.shape[1]
+
+ else:
+ raise TypeError("Number of columns in shape must be either an integer, "
+ "or a expression that evaluates to a scalar integer. "
+ f"But given value has type {type(self.shape[0])}")
else:
- raise ValueError("Shape must be a tuple or expression that "
+ raise TypeError("Shape must be a tuple or expression that "
f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}")
+ state = state.register(self.symbol, shape=(nrows, ncols))
indices = state.get_indices(self.symbol)
p = PolyMatrixDict()