summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-11 18:10:22 +0200
committerNao Pross <np@0hm.ch>2024-05-11 18:56:54 +0200
commit038e6f27b8bd0af431a8fe0a2f31255990b2dae1 (patch)
tree38f82176e3fa08db27525a06e2ed9e21e6e7bc73
parentMark get_variable_indices_from_variable as deprecated (diff)
downloadpolymatrix-038e6f27b8bd0af431a8fe0a2f31255990b2dae1.tar.gz
polymatrix-038e6f27b8bd0af431a8fe0a2f31255990b2dae1.zip
Delete ParametrizeMatrixExprMixin, update ParametrizeExprMixin to work with any shape
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/impl.py37
-rw-r--r--polymatrix/expression/init.py10
-rw-r--r--polymatrix/expression/mixins/parametrizeexprmixin.py31
-rw-r--r--polymatrix/expression/mixins/parametrizematrixexprmixin.py59
-rw-r--r--polymatrix/expression/to.py4
5 files changed, 22 insertions, 119 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index e6b021f..19fd4dd 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -21,9 +21,7 @@ from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin
from polymatrix.expression.mixins.evalexprmixin import EvalExprMixin
from polymatrix.expression.mixins.eyeexprmixin import EyeExprMixin
from polymatrix.expression.mixins.filterexprmixin import FilterExprMixin
-from polymatrix.expression.mixins.fromsymmetricmatrixexprmixin import (
- FromSymmetricMatrixExprMixin,
-)
+from polymatrix.expression.mixins.fromsymmetricmatrixexprmixin import FromSymmetricMatrixExprMixin
from polymatrix.expression.mixins.fromnumbersexprmixin import FromNumbersExprMixin
from polymatrix.expression.mixins.fromnumpyexprmixin import FromNumpyExprMixin
from polymatrix.expression.mixins.fromstatemonad import FromStateMonadMixin
@@ -33,47 +31,32 @@ from polymatrix.expression.mixins.fromtermsexprmixin import (
PolynomialMatrixTupledData,
)
from polymatrix.expression.mixins.getitemexprmixin import GetItemExprMixin
-from polymatrix.expression.mixins.halfnewtonpolytopeexprmixin import (
- HalfNewtonPolytopeExprMixin,
-)
+from polymatrix.expression.mixins.halfnewtonpolytopeexprmixin import HalfNewtonPolytopeExprMixin
from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin
from polymatrix.expression.mixins.linearmatrixinexprmixin import LinearMatrixInExprMixin
-from polymatrix.expression.mixins.linearmonomialsexprmixin import (
- LinearMonomialsExprMixin,
-)
+from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomialsExprMixin
from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin
from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin
from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin
from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin
-from polymatrix.expression.mixins.parametrizematrixexprmixin import (
- ParametrizeMatrixExprMixin,
-)
from polymatrix.expression.mixins.powerexprmixin import PowerExprMixin
from polymatrix.expression.mixins.quadraticinexprmixin import QuadraticInExprMixin
-from polymatrix.expression.mixins.quadraticmonomialsexprmixin import (
- QuadraticMonomialsExprMixin,
-)
+from polymatrix.expression.mixins.quadraticmonomialsexprmixin import QuadraticMonomialsExprMixin
from polymatrix.expression.mixins.repmatexprmixin import RepMatExprMixin
from polymatrix.expression.mixins.reshapeexprmixin import ReshapeExprMixin
from polymatrix.expression.mixins.setelementatexprmixin import SetElementAtExprMixin
from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin
from polymatrix.expression.mixins.substituteexprmixin import SubstituteExprMixin
-from polymatrix.expression.mixins.subtractmonomialsexprmixin import (
- SubtractMonomialsExprMixin,
-)
+from polymatrix.expression.mixins.subtractmonomialsexprmixin import SubtractMonomialsExprMixin
from polymatrix.expression.mixins.sumexprmixin import SumExprMixin
from polymatrix.expression.mixins.symmetricexprmixin import SymmetricExprMixin
from polymatrix.expression.mixins.toconstantexprmixin import ToConstantExprMixin
-from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import (
- ToSymmetricMatrixExprMixin,
-)
+from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ToSymmetricMatrixExprMixin
from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin
from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin
from polymatrix.expression.mixins.variablemixin import VariableMixin
from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin
-from polymatrix.expression.mixins.tosortedvariablesmixin import (
- ToSortedVariablesExprMixin,
-)
+from polymatrix.expression.mixins.tosortedvariablesmixin import ToSortedVariablesExprMixin
@dataclassabc.dataclassabc(frozen=True)
@@ -310,12 +293,6 @@ class ParametrizeExprImpl(ParametrizeExprMixin):
@dataclassabc.dataclassabc(frozen=True)
-class ParametrizeMatrixExprImpl(ParametrizeMatrixExprMixin):
- underlying: ExpressionBaseMixin
- name: str
-
-
-@dataclassabc.dataclassabc(frozen=True)
class PowerExprImpl(PowerExprMixin):
left: ExpressionBaseMixin
right: ExpressionBaseMixin | int | float
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index bb3da9a..7eab999 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -328,16 +328,6 @@ def init_parametrize_expr(
)
-def init_parametrize_matrix_expr(
- underlying: ExpressionBaseMixin,
- name: str,
-):
- return polymatrix.expression.impl.ParametrizeMatrixExprImpl(
- underlying=underlying,
- name=name,
- )
-
-
def init_power_expr(
left: ExpressionBaseMixin,
right: ExpressionBaseMixin,
diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py
index 5762480..ca49411 100644
--- a/polymatrix/expression/mixins/parametrizeexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizeexprmixin.py
@@ -1,5 +1,7 @@
-import abc
-import dataclasses
+from __future__ import annotations
+
+from abc import abstractmethod
+from itertools import product
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins import ExpressionStateMixin
@@ -11,9 +13,9 @@ from polymatrix.variable.init import init_variable
class ParametrizeExprMixin(ExpressionBaseMixin):
r"""
- Given a column vector (or an expression that evaluates to a column vector)
- :math:`x \in \mathbf{R}^n` and a name, for instance :math:`u`, create a new
- vector of variables called :math:`u` of the same size as :math:`x`.
+ Given matrix or vector (or an expression) :math:`x` and a name, for
+ instance :math:`u`, create a new variable :math:`u` of the same shape as
+ :math:`x`.
This is useful if you want to create coefficients, for example:
@@ -26,39 +28,36 @@ class ParametrizeExprMixin(ExpressionBaseMixin):
"""
@property
- @abc.abstractclassmethod
+ @abstractmethod
def underlying(self) -> ExpressionBaseMixin: ...
@property
- @abc.abstractclassmethod
+ @abstractmethod
def name(self) -> str: ...
# overwrites the abstract method of `ExpressionBaseMixin`
def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
state, underlying = self.underlying.apply(state)
-
nrows, ncols = underlying.shape
- if ncols != 1:
- raise ValueError("Parametrize works only with column vectors")
# FIXME: not sure this behaviour is intuitive, discuss
if v := state.get_variable_from_name_or(self.name, if_not_present=False):
start, end = state.offset_dict[self.name]
- if nrows != (end - start):
+ if (nrows * ncols) != (end - start):
raise ValueError("Cannot parametrize {self.underlying} with variable {v} "
"found in state object, because its shape {(nrow, ncols)} "
"does not match ({self.underlying.shape}). ")
else:
- v = init_variable(self.name, shape=(nrows, 1))
+ v = init_variable(self.name, shape=(nrows, ncols))
state = state.register(v)
- p = PolyMatrixDict({
- MatrixIndex(row, 0): PolyDict({
+ p = PolyMatrixDict.empty()
+ indices = state.get_indices(v)
+ for (row, col), index in zip(product(range(nrows), range(ncols)), indices):
+ p[row, col] = PolyDict({
MonomialIndex((VariableIndex(index, 1),)): 1
})
- for row, index in enumerate(state.get_indices(v))
- })
return state, init_poly_matrix(p, underlying.shape)
diff --git a/polymatrix/expression/mixins/parametrizematrixexprmixin.py b/polymatrix/expression/mixins/parametrizematrixexprmixin.py
deleted file mode 100644
index 175a5c3..0000000
--- a/polymatrix/expression/mixins/parametrizematrixexprmixin.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import abc
-import dataclasses
-
-from polymatrix.polymatrix.init import init_poly_matrix
-from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.expressionstate.mixins import ExpressionStateMixin
-from polymatrix.polymatrix.mixins import PolyMatrixMixin
-
-
-# remove?
-class ParametrizeMatrixExprMixin(ExpressionBaseMixin):
- @property
- @abc.abstractclassmethod
- def underlying(self) -> ExpressionBaseMixin: ...
-
- @property
- @abc.abstractclassmethod
- def name(self) -> str: ...
-
- # overwrites the abstract method of `ExpressionBaseMixin`
- def apply(
- self,
- state: ExpressionStateMixin,
- ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
- # cache polymatrix to not re-parametrize at every apply call
- if self in state.cache:
- return state, state.cache[self]
-
- state, underlying = self.underlying.apply(state)
-
- assert underlying.shape[1] == 1
-
- poly_matrix_data = {}
- var_index = 0
-
- for row in range(underlying.shape[0]):
- for _ in range(row, underlying.shape[0]):
- poly_matrix_data[var_index, 0] = {
- ((state.n_param + var_index, 1),): 1.0
- }
-
- var_index += 1
-
- state = state.register(
- key=self,
- n_param=var_index,
- )
-
- poly_matrix = init_poly_matrix(
- data=poly_matrix_data,
- shape=(var_index, 1),
- )
-
- state = dataclasses.replace(
- state,
- cache=state.cache | {self: poly_matrix},
- )
-
- return state, poly_matrix
diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py
index b1de626..c24b1d2 100644
--- a/polymatrix/expression/to.py
+++ b/polymatrix/expression/to.py
@@ -3,10 +3,6 @@ import sympy
import numpy as np
from polymatrix.expression.expression import Expression
-from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin
-from polymatrix.expression.mixins.parametrizematrixexprmixin import (
- ParametrizeMatrixExprMixin,
-)
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.statemonad.init import init_state_monad
from polymatrix.statemonad.mixins import StateMonadMixin