diff options
author | Nao Pross <np@0hm.ch> | 2024-05-24 18:39:13 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-24 18:47:24 +0200 |
commit | 8301c6c16fe1117634f814d23f37496819e59408 (patch) | |
tree | a954a1255121b25b0c5263336390a335e7b20a7b | |
parent | Create ShapeExprMixin (diff) | |
download | polymatrix-8301c6c16fe1117634f814d23f37496819e59408.tar.gz polymatrix-8301c6c16fe1117634f814d23f37496819e59408.zip |
Allow shape of VariableExpr to be an Expression
Does this, combined with ShapeExpr make ParametrieExpr obsolete?
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/from_.py | 7 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 5 | ||||
-rw-r--r-- | polymatrix/expression/mixins/variablemixin.py | 43 |
4 files changed, 45 insertions, 12 deletions
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index 88181a8..8dbc055 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -37,12 +37,15 @@ def from_statemonad(monad: StateMonad) -> Expression: return init_expression(polymatrix.expression.init.init_from_statemonad(monad)) -def from_names(names: str, shape: tuple[int, int] = (1,1)) -> Iterable[VariableExpression]: +def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> Iterable[VariableExpression]: """ Construct one or multiple variables from comma separated a list of names. """ for name in names.split(","): yield from_name(name.strip(), shape) -def from_name(name: str, shape: tuple[int, int] = (1,1)) -> VariableExpression: +def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> VariableExpression: """ Construct a variable from its names """ + if isinstance(shape, Expression): + shape = shape.underlying + return init_variable_expression(polymatrix.expression.init.init_variable_expr(name, shape)) diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index f4c4df4..177d44e 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -433,7 +433,7 @@ class TruncateExprImpl(TruncateExprMixin): @dataclassabc.dataclassabc(frozen=True) class VariableImpl(VariableMixin): name: str - shape: tuple[int, int] + shape: tuple[int, int] | ExpressionBaseMixin def __str__(self): return self.name diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 60fa580..c109156 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -483,5 +483,8 @@ def init_truncate_expr( ) -def init_variable_expr(name: str, shape: tuple[int, int]): +def init_variable_expr( + name: str, + shape: tuple[int, int] | ExpressionBaseMixin +): return polymatrix.expression.impl.VariableImpl(name, shape) diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py index 8bc8b94..355b6b1 100644 --- a/polymatrix/expression/mixins/variablemixin.py +++ b/polymatrix/expression/mixins/variablemixin.py @@ -1,12 +1,14 @@ from __future__ import annotations -import typing import itertools +from abc import abstractmethod +from dataclasses import replace from typing_extensions import override from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex from polymatrix.variable import Variable @@ -15,18 +17,43 @@ class VariableMixin(ExpressionBaseMixin, Variable): """ Underlying object for VariableExpression """ @override + @property + @abstractmethod + def shape(self) -> tuple[int, int] | ExpressionBaseMixin: + """ Shape of the variable expression. """ + + @override def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: - state = state.register(self) - indices = state.get_indices(self) - p = PolyMatrixDict() + if isinstance(self.shape, ExpressionBaseMixin): + state, shape_pm = self.shape.apply(state) + if shape_pm.shape != (2, 1): + raise ValueError("If shape is an expression it must evaluate to a 2d row vector, " + f"but here it has shape {shape_pm.shape}") + + # FIXME: should check that they are actually integers + nrows = int(shape_pm.at(0, 0).constant()) + ncols = int(shape_pm.at(1, 0).constant()) + + # Replace shape field with computed shape + v = replace(self, shape=(nrows, ncols)) + state = state.register(v) + indices = state.get_indices(v) + + elif isinstance(self.shape, tuple): + nrows, ncols = self.shape + state = state.register(self) + indices = state.get_indices(self) + + else: + raise ValueError("Shape must be a tuple or expression that " + f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}") - rows, cols = self.shape - for (row, col), index in zip(itertools.product(range(rows), range(cols)), indices): + p = PolyMatrixDict() + for (row, col), index in zip(itertools.product(range(nrows), range(ncols)), indices): p[row, col] = PolyDict({ # Create monomial with variable to the first power # with coefficient of one MonomialIndex((VariableIndex(index, power=1),)): 1. }) - return state, init_poly_matrix(p, self.shape) - + return state, init_poly_matrix(p, shape=(nrows, ncols)) |