From 038e6f27b8bd0af431a8fe0a2f31255990b2dae1 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 11 May 2024 18:10:22 +0200 Subject: Delete ParametrizeMatrixExprMixin, update ParametrizeExprMixin to work with any shape --- polymatrix/expression/impl.py | 37 +++----------- polymatrix/expression/init.py | 10 ---- .../expression/mixins/parametrizeexprmixin.py | 31 ++++++------ .../mixins/parametrizematrixexprmixin.py | 59 ---------------------- polymatrix/expression/to.py | 4 -- 5 files changed, 22 insertions(+), 119 deletions(-) delete mode 100644 polymatrix/expression/mixins/parametrizematrixexprmixin.py diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index e6b021f..19fd4dd 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -21,9 +21,7 @@ from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin from polymatrix.expression.mixins.evalexprmixin import EvalExprMixin from polymatrix.expression.mixins.eyeexprmixin import EyeExprMixin from polymatrix.expression.mixins.filterexprmixin import FilterExprMixin -from polymatrix.expression.mixins.fromsymmetricmatrixexprmixin import ( - FromSymmetricMatrixExprMixin, -) +from polymatrix.expression.mixins.fromsymmetricmatrixexprmixin import FromSymmetricMatrixExprMixin from polymatrix.expression.mixins.fromnumbersexprmixin import FromNumbersExprMixin from polymatrix.expression.mixins.fromnumpyexprmixin import FromNumpyExprMixin from polymatrix.expression.mixins.fromstatemonad import FromStateMonadMixin @@ -33,47 +31,32 @@ from polymatrix.expression.mixins.fromtermsexprmixin import ( PolynomialMatrixTupledData, ) from polymatrix.expression.mixins.getitemexprmixin import GetItemExprMixin -from polymatrix.expression.mixins.halfnewtonpolytopeexprmixin import ( - HalfNewtonPolytopeExprMixin, -) +from polymatrix.expression.mixins.halfnewtonpolytopeexprmixin import HalfNewtonPolytopeExprMixin from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin from polymatrix.expression.mixins.linearmatrixinexprmixin import LinearMatrixInExprMixin -from polymatrix.expression.mixins.linearmonomialsexprmixin import ( - LinearMonomialsExprMixin, -) +from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomialsExprMixin from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin -from polymatrix.expression.mixins.parametrizematrixexprmixin import ( - ParametrizeMatrixExprMixin, -) from polymatrix.expression.mixins.powerexprmixin import PowerExprMixin from polymatrix.expression.mixins.quadraticinexprmixin import QuadraticInExprMixin -from polymatrix.expression.mixins.quadraticmonomialsexprmixin import ( - QuadraticMonomialsExprMixin, -) +from polymatrix.expression.mixins.quadraticmonomialsexprmixin import QuadraticMonomialsExprMixin from polymatrix.expression.mixins.repmatexprmixin import RepMatExprMixin from polymatrix.expression.mixins.reshapeexprmixin import ReshapeExprMixin from polymatrix.expression.mixins.setelementatexprmixin import SetElementAtExprMixin from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin from polymatrix.expression.mixins.substituteexprmixin import SubstituteExprMixin -from polymatrix.expression.mixins.subtractmonomialsexprmixin import ( - SubtractMonomialsExprMixin, -) +from polymatrix.expression.mixins.subtractmonomialsexprmixin import SubtractMonomialsExprMixin from polymatrix.expression.mixins.sumexprmixin import SumExprMixin from polymatrix.expression.mixins.symmetricexprmixin import SymmetricExprMixin from polymatrix.expression.mixins.toconstantexprmixin import ToConstantExprMixin -from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ( - ToSymmetricMatrixExprMixin, -) +from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ToSymmetricMatrixExprMixin from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin from polymatrix.expression.mixins.variablemixin import VariableMixin from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin -from polymatrix.expression.mixins.tosortedvariablesmixin import ( - ToSortedVariablesExprMixin, -) +from polymatrix.expression.mixins.tosortedvariablesmixin import ToSortedVariablesExprMixin @dataclassabc.dataclassabc(frozen=True) @@ -309,12 +292,6 @@ class ParametrizeExprImpl(ParametrizeExprMixin): ) -@dataclassabc.dataclassabc(frozen=True) -class ParametrizeMatrixExprImpl(ParametrizeMatrixExprMixin): - underlying: ExpressionBaseMixin - name: str - - @dataclassabc.dataclassabc(frozen=True) class PowerExprImpl(PowerExprMixin): left: ExpressionBaseMixin diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index bb3da9a..7eab999 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -328,16 +328,6 @@ def init_parametrize_expr( ) -def init_parametrize_matrix_expr( - underlying: ExpressionBaseMixin, - name: str, -): - return polymatrix.expression.impl.ParametrizeMatrixExprImpl( - underlying=underlying, - name=name, - ) - - def init_power_expr( left: ExpressionBaseMixin, right: ExpressionBaseMixin, diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py index 5762480..ca49411 100644 --- a/polymatrix/expression/mixins/parametrizeexprmixin.py +++ b/polymatrix/expression/mixins/parametrizeexprmixin.py @@ -1,5 +1,7 @@ -import abc -import dataclasses +from __future__ import annotations + +from abc import abstractmethod +from itertools import product from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate.mixins import ExpressionStateMixin @@ -11,9 +13,9 @@ from polymatrix.variable.init import init_variable class ParametrizeExprMixin(ExpressionBaseMixin): r""" - Given a column vector (or an expression that evaluates to a column vector) - :math:`x \in \mathbf{R}^n` and a name, for instance :math:`u`, create a new - vector of variables called :math:`u` of the same size as :math:`x`. + Given matrix or vector (or an expression) :math:`x` and a name, for + instance :math:`u`, create a new variable :math:`u` of the same shape as + :math:`x`. This is useful if you want to create coefficients, for example: @@ -26,39 +28,36 @@ class ParametrizeExprMixin(ExpressionBaseMixin): """ @property - @abc.abstractclassmethod + @abstractmethod def underlying(self) -> ExpressionBaseMixin: ... @property - @abc.abstractclassmethod + @abstractmethod def name(self) -> str: ... # overwrites the abstract method of `ExpressionBaseMixin` def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) - nrows, ncols = underlying.shape - if ncols != 1: - raise ValueError("Parametrize works only with column vectors") # FIXME: not sure this behaviour is intuitive, discuss if v := state.get_variable_from_name_or(self.name, if_not_present=False): start, end = state.offset_dict[self.name] - if nrows != (end - start): + if (nrows * ncols) != (end - start): raise ValueError("Cannot parametrize {self.underlying} with variable {v} " "found in state object, because its shape {(nrow, ncols)} " "does not match ({self.underlying.shape}). ") else: - v = init_variable(self.name, shape=(nrows, 1)) + v = init_variable(self.name, shape=(nrows, ncols)) state = state.register(v) - p = PolyMatrixDict({ - MatrixIndex(row, 0): PolyDict({ + p = PolyMatrixDict.empty() + indices = state.get_indices(v) + for (row, col), index in zip(product(range(nrows), range(ncols)), indices): + p[row, col] = PolyDict({ MonomialIndex((VariableIndex(index, 1),)): 1 }) - for row, index in enumerate(state.get_indices(v)) - }) return state, init_poly_matrix(p, underlying.shape) diff --git a/polymatrix/expression/mixins/parametrizematrixexprmixin.py b/polymatrix/expression/mixins/parametrizematrixexprmixin.py deleted file mode 100644 index 175a5c3..0000000 --- a/polymatrix/expression/mixins/parametrizematrixexprmixin.py +++ /dev/null @@ -1,59 +0,0 @@ -import abc -import dataclasses - -from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expressionstate.mixins import ExpressionStateMixin -from polymatrix.polymatrix.mixins import PolyMatrixMixin - - -# remove? -class ParametrizeMatrixExprMixin(ExpressionBaseMixin): - @property - @abc.abstractclassmethod - def underlying(self) -> ExpressionBaseMixin: ... - - @property - @abc.abstractclassmethod - def name(self) -> str: ... - - # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: - # cache polymatrix to not re-parametrize at every apply call - if self in state.cache: - return state, state.cache[self] - - state, underlying = self.underlying.apply(state) - - assert underlying.shape[1] == 1 - - poly_matrix_data = {} - var_index = 0 - - for row in range(underlying.shape[0]): - for _ in range(row, underlying.shape[0]): - poly_matrix_data[var_index, 0] = { - ((state.n_param + var_index, 1),): 1.0 - } - - var_index += 1 - - state = state.register( - key=self, - n_param=var_index, - ) - - poly_matrix = init_poly_matrix( - data=poly_matrix_data, - shape=(var_index, 1), - ) - - state = dataclasses.replace( - state, - cache=state.cache | {self: poly_matrix}, - ) - - return state, poly_matrix diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py index b1de626..c24b1d2 100644 --- a/polymatrix/expression/to.py +++ b/polymatrix/expression/to.py @@ -3,10 +3,6 @@ import sympy import numpy as np from polymatrix.expression.expression import Expression -from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin -from polymatrix.expression.mixins.parametrizematrixexprmixin import ( - ParametrizeMatrixExprMixin, -) from polymatrix.expressionstate.abc import ExpressionState from polymatrix.statemonad.init import init_state_monad from polymatrix.statemonad.mixins import StateMonadMixin -- cgit v1.2.1