summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/from_.py7
-rw-r--r--polymatrix/expression/impl.py2
-rw-r--r--polymatrix/expression/init.py5
-rw-r--r--polymatrix/expression/mixins/variablemixin.py43
4 files changed, 45 insertions, 12 deletions
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index 88181a8..8dbc055 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -37,12 +37,15 @@ def from_statemonad(monad: StateMonad) -> Expression:
return init_expression(polymatrix.expression.init.init_from_statemonad(monad))
-def from_names(names: str, shape: tuple[int, int] = (1,1)) -> Iterable[VariableExpression]:
+def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> Iterable[VariableExpression]:
""" Construct one or multiple variables from comma separated a list of names. """
for name in names.split(","):
yield from_name(name.strip(), shape)
-def from_name(name: str, shape: tuple[int, int] = (1,1)) -> VariableExpression:
+def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> VariableExpression:
""" Construct a variable from its names """
+ if isinstance(shape, Expression):
+ shape = shape.underlying
+
return init_variable_expression(polymatrix.expression.init.init_variable_expr(name, shape))
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index f4c4df4..177d44e 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -433,7 +433,7 @@ class TruncateExprImpl(TruncateExprMixin):
@dataclassabc.dataclassabc(frozen=True)
class VariableImpl(VariableMixin):
name: str
- shape: tuple[int, int]
+ shape: tuple[int, int] | ExpressionBaseMixin
def __str__(self):
return self.name
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 60fa580..c109156 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -483,5 +483,8 @@ def init_truncate_expr(
)
-def init_variable_expr(name: str, shape: tuple[int, int]):
+def init_variable_expr(
+ name: str,
+ shape: tuple[int, int] | ExpressionBaseMixin
+):
return polymatrix.expression.impl.VariableImpl(name, shape)
diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py
index 8bc8b94..355b6b1 100644
--- a/polymatrix/expression/mixins/variablemixin.py
+++ b/polymatrix/expression/mixins/variablemixin.py
@@ -1,12 +1,14 @@
from __future__ import annotations
-import typing
import itertools
+from abc import abstractmethod
+from dataclasses import replace
from typing_extensions import override
from polymatrix.expressionstate import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex
from polymatrix.variable import Variable
@@ -15,18 +17,43 @@ class VariableMixin(ExpressionBaseMixin, Variable):
""" Underlying object for VariableExpression """
@override
+ @property
+ @abstractmethod
+ def shape(self) -> tuple[int, int] | ExpressionBaseMixin:
+ """ Shape of the variable expression. """
+
+ @override
def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
- state = state.register(self)
- indices = state.get_indices(self)
- p = PolyMatrixDict()
+ if isinstance(self.shape, ExpressionBaseMixin):
+ state, shape_pm = self.shape.apply(state)
+ if shape_pm.shape != (2, 1):
+ raise ValueError("If shape is an expression it must evaluate to a 2d row vector, "
+ f"but here it has shape {shape_pm.shape}")
+
+ # FIXME: should check that they are actually integers
+ nrows = int(shape_pm.at(0, 0).constant())
+ ncols = int(shape_pm.at(1, 0).constant())
+
+ # Replace shape field with computed shape
+ v = replace(self, shape=(nrows, ncols))
+ state = state.register(v)
+ indices = state.get_indices(v)
+
+ elif isinstance(self.shape, tuple):
+ nrows, ncols = self.shape
+ state = state.register(self)
+ indices = state.get_indices(self)
+
+ else:
+ raise ValueError("Shape must be a tuple or expression that "
+ f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}")
- rows, cols = self.shape
- for (row, col), index in zip(itertools.product(range(rows), range(cols)), indices):
+ p = PolyMatrixDict()
+ for (row, col), index in zip(itertools.product(range(nrows), range(ncols)), indices):
p[row, col] = PolyDict({
# Create monomial with variable to the first power
# with coefficient of one
MonomialIndex((VariableIndex(index, power=1),)): 1.
})
- return state, init_poly_matrix(p, self.shape)
-
+ return state, init_poly_matrix(p, shape=(nrows, ncols))