From 8301c6c16fe1117634f814d23f37496819e59408 Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Fri, 24 May 2024 18:39:13 +0200
Subject: Allow shape of VariableExpr to be an Expression

Does this, combined with ShapeExpr make ParametrieExpr obsolete?
---
 polymatrix/expression/from_.py                |  7 +++--
 polymatrix/expression/impl.py                 |  2 +-
 polymatrix/expression/init.py                 |  5 +++-
 polymatrix/expression/mixins/variablemixin.py | 43 ++++++++++++++++++++++-----
 4 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index 88181a8..8dbc055 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -37,12 +37,15 @@ def from_statemonad(monad: StateMonad) -> Expression:
     return init_expression(polymatrix.expression.init.init_from_statemonad(monad))
 
 
-def from_names(names: str, shape: tuple[int, int] = (1,1)) -> Iterable[VariableExpression]:
+def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> Iterable[VariableExpression]:
     """ Construct one or multiple variables from comma separated a list of names. """
     for name in names.split(","):
         yield from_name(name.strip(), shape)
 
 
-def from_name(name: str, shape: tuple[int, int] = (1,1)) -> VariableExpression:
+def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> VariableExpression:
     """ Construct a variable from its names """
+    if isinstance(shape, Expression):
+        shape = shape.underlying
+
     return init_variable_expression(polymatrix.expression.init.init_variable_expr(name, shape))
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index f4c4df4..177d44e 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -433,7 +433,7 @@ class TruncateExprImpl(TruncateExprMixin):
 @dataclassabc.dataclassabc(frozen=True)
 class VariableImpl(VariableMixin):
     name: str
-    shape: tuple[int, int]
+    shape: tuple[int, int] | ExpressionBaseMixin
 
     def __str__(self):
         return self.name
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 60fa580..c109156 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -483,5 +483,8 @@ def init_truncate_expr(
     )
 
 
-def init_variable_expr(name: str, shape: tuple[int, int]):
+def init_variable_expr(
+        name: str,
+        shape: tuple[int, int] | ExpressionBaseMixin
+):
     return polymatrix.expression.impl.VariableImpl(name, shape)
diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py
index 8bc8b94..355b6b1 100644
--- a/polymatrix/expression/mixins/variablemixin.py
+++ b/polymatrix/expression/mixins/variablemixin.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
-import typing
 import itertools
+from abc import  abstractmethod
+from dataclasses import replace
 from typing_extensions import override
 
 from polymatrix.expressionstate import ExpressionState
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
 from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex
 from polymatrix.variable import Variable
 
@@ -14,19 +16,44 @@ from polymatrix.variable import Variable
 class VariableMixin(ExpressionBaseMixin, Variable):
     """ Underlying object for VariableExpression """
 
+    @override
+    @property
+    @abstractmethod
+    def shape(self) -> tuple[int, int] | ExpressionBaseMixin:
+        """ Shape of the variable expression. """
+
     @override
     def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
-        state = state.register(self) 
-        indices = state.get_indices(self)
-        p = PolyMatrixDict()
+        if isinstance(self.shape, ExpressionBaseMixin):
+            state, shape_pm = self.shape.apply(state)
+            if shape_pm.shape != (2, 1):
+                raise ValueError("If shape is an expression it must evaluate to a 2d row vector, "
+                                 f"but here it has shape {shape_pm.shape}")
+
+            # FIXME: should check that they are actually integers
+            nrows = int(shape_pm.at(0, 0).constant())
+            ncols = int(shape_pm.at(1, 0).constant())
+
+            # Replace shape field with computed shape
+            v = replace(self, shape=(nrows, ncols))
+            state = state.register(v) 
+            indices = state.get_indices(v)
+
+        elif isinstance(self.shape, tuple):
+            nrows, ncols = self.shape
+            state = state.register(self) 
+            indices = state.get_indices(self)
+
+        else:
+            raise ValueError("Shape must be a tuple or expression that "
+                f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}")
 
-        rows, cols = self.shape
-        for (row, col), index in zip(itertools.product(range(rows), range(cols)), indices):
+        p = PolyMatrixDict()
+        for (row, col), index in zip(itertools.product(range(nrows), range(ncols)), indices):
             p[row, col] = PolyDict({
                 # Create monomial with variable to the first power
                 # with coefficient of one
                 MonomialIndex((VariableIndex(index, power=1),)): 1.
             })
 
-        return state, init_poly_matrix(p, self.shape)
-
+        return state, init_poly_matrix(p, shape=(nrows, ncols))
-- 
cgit v1.2.1