summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/mixins/maxexprmixin.py102
-rw-r--r--polymatrix/polymatrix/index.py5
2 files changed, 82 insertions, 25 deletions
diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py
index 60c1832..30d3917 100644
--- a/polymatrix/expression/mixins/maxexprmixin.py
+++ b/polymatrix/expression/mixins/maxexprmixin.py
@@ -1,44 +1,96 @@
-import abc
+from abc import abstractmethod
+from typing_extensions import override
-from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate import ExpressionState
+from polymatrix.polymatrix.index import PolyMatrixDict, MatrixIndex, PolyDict, MonomialIndex
+from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.polymatrix.mixins import PolyMatrixMixin
-# remove?
class MaxExprMixin(ExpressionBaseMixin):
+ """
+ Keep only the maximum of an element of vector, returns a scalar. If the
+ underlying is a matrix then the maximum is interpreted row-wise, so it
+ returns a column vector.
+
+ The vector (or matrix) may not contain polynomials.
+ """
+
@property
- @abc.abstractclassmethod
+ @abstractmethod
def underlying(self) -> ExpressionBaseMixin: ...
- # overwrites the abstract method of `ExpressionBaseMixin`
- def apply(
- self,
- state: ExpressionState,
- ) -> tuple[ExpressionState, PolyMatrixMixin]:
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
state, underlying = self.underlying.apply(state)
+ nrows, ncols = underlying.shape
+
+ # row vector
+ if nrows == 1:
+ maximum = None
+ for col in range(ncols):
+ if poly := underlying.at(0, col):
+ if poly.degree > 0:
+ raise ValueError("Cannot take maximum of non-costant vector. "
+ f"Entry at {(0,col)} has degree {poly.degree}.")
+
+ if maximum is None or poly.constant() > maximum:
+ maximum = poly.constant()
+
+ assert maximum is not None
+
+ shape = (1, 1)
+ p = PolyMatrixDict({
+ MatrixIndex(0, 0): PolyDict({
+ MonomialIndex.constant(): maximum
+ })
+ })
- poly_matrix_data = {}
+ # column vector
+ elif ncols == 1:
+ maximum = None
+ for row in range(nrows):
+ if poly := underlying.at(row, col):
+ if poly.degree > 0:
+ raise ValueError("Cannot take maximum of non-costant vector. "
+ f"Entry at {(row, 0)} has degree {poly.degree}.")
- for row in range(underlying.shape[0]):
+ if maximum is None or poly.constant() > maximum:
+ maximum = poly.constant()
- def gen_values():
- for col in range(underlying.shape[1]):
- polynomial = underlying.get_poly(row, col)
- if polynomial is None:
- continue
+ assert maximum is not None
+
+ shape = (1, 1)
+ p = PolyMatrixDict({
+ MatrixIndex(0, 0): PolyDict({
+ MonomialIndex.constant(): maximum
+ })
+ })
- yield polynomial[tuple()]
+ # matrix
+ else:
+ maxima = []
+ for row in range(nrows):
+ maximum = None
+ for col in range(ncols):
+ if poly := underlying.at(row, col):
+ if poly.degree > 0:
+ raise ValueError("Cannot take maximum of non-costant vector. "
+ f"Entry at {(row, 0)} has degree {poly.degree}.")
- values = tuple(gen_values())
+ if maximum is None or poly.constant() > maximum:
+ maximum = poly.constant()
- if 0 < len(values):
- poly_matrix_data[row, 0] = {tuple(): max(values)}
+ assert maximum is not None
+ maxima.append(maximum)
- poly_matrix = init_poly_matrix(
- data=poly_matrix_data,
- shape=(underlying.shape[0], 1),
- )
+ shape = (nrows, 1)
+ p = PolyMatrixDict({
+ MatrixIndex(i, 0): PolyDict({
+ MonomialIndex.constant(): row_max
+ })
+ for i, row_max in enumerate(maxima)
+ })
- return state, poly_matrix
+ return state, init_poly_matrix(p, shape)
diff --git a/polymatrix/polymatrix/index.py b/polymatrix/polymatrix/index.py
index bab991f..c68e23f 100644
--- a/polymatrix/polymatrix/index.py
+++ b/polymatrix/polymatrix/index.py
@@ -159,6 +159,11 @@ class MonomialIndex(tuple[VariableIndex]):
class PolyDict(UserDict[MonomialIndex, int | float]):
""" Polynomial, stored as a dictionary. """
+ @property
+ def degree(self) -> int:
+ """ Degree of the polynomial. """
+ return max(m.degree for m in self.monomials())
+
@staticmethod
def empty() -> PolyDict:
return PolyDict({})