diff options
-rw-r--r-- | polymatrix/expression/mixins/maxexprmixin.py | 102 | ||||
-rw-r--r-- | polymatrix/polymatrix/index.py | 5 |
2 files changed, 82 insertions, 25 deletions
diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py index 60c1832..30d3917 100644 --- a/polymatrix/expression/mixins/maxexprmixin.py +++ b/polymatrix/expression/mixins/maxexprmixin.py @@ -1,44 +1,96 @@ -import abc +from abc import abstractmethod +from typing_extensions import override -from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.index import PolyMatrixDict, MatrixIndex, PolyDict, MonomialIndex +from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.mixins import PolyMatrixMixin -# remove? class MaxExprMixin(ExpressionBaseMixin): + """ + Keep only the maximum of an element of vector, returns a scalar. If the + underlying is a matrix then the maximum is interpreted row-wise, so it + returns a column vector. + + The vector (or matrix) may not contain polynomials. + """ + @property - @abc.abstractclassmethod + @abstractmethod def underlying(self) -> ExpressionBaseMixin: ... - # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrixMixin]: + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state, underlying = self.underlying.apply(state) + nrows, ncols = underlying.shape + + # row vector + if nrows == 1: + maximum = None + for col in range(ncols): + if poly := underlying.at(0, col): + if poly.degree > 0: + raise ValueError("Cannot take maximum of non-costant vector. " + f"Entry at {(0,col)} has degree {poly.degree}.") + + if maximum is None or poly.constant() > maximum: + maximum = poly.constant() + + assert maximum is not None + + shape = (1, 1) + p = PolyMatrixDict({ + MatrixIndex(0, 0): PolyDict({ + MonomialIndex.constant(): maximum + }) + }) - poly_matrix_data = {} + # column vector + elif ncols == 1: + maximum = None + for row in range(nrows): + if poly := underlying.at(row, col): + if poly.degree > 0: + raise ValueError("Cannot take maximum of non-costant vector. " + f"Entry at {(row, 0)} has degree {poly.degree}.") - for row in range(underlying.shape[0]): + if maximum is None or poly.constant() > maximum: + maximum = poly.constant() - def gen_values(): - for col in range(underlying.shape[1]): - polynomial = underlying.get_poly(row, col) - if polynomial is None: - continue + assert maximum is not None + + shape = (1, 1) + p = PolyMatrixDict({ + MatrixIndex(0, 0): PolyDict({ + MonomialIndex.constant(): maximum + }) + }) - yield polynomial[tuple()] + # matrix + else: + maxima = [] + for row in range(nrows): + maximum = None + for col in range(ncols): + if poly := underlying.at(row, col): + if poly.degree > 0: + raise ValueError("Cannot take maximum of non-costant vector. " + f"Entry at {(row, 0)} has degree {poly.degree}.") - values = tuple(gen_values()) + if maximum is None or poly.constant() > maximum: + maximum = poly.constant() - if 0 < len(values): - poly_matrix_data[row, 0] = {tuple(): max(values)} + assert maximum is not None + maxima.append(maximum) - poly_matrix = init_poly_matrix( - data=poly_matrix_data, - shape=(underlying.shape[0], 1), - ) + shape = (nrows, 1) + p = PolyMatrixDict({ + MatrixIndex(i, 0): PolyDict({ + MonomialIndex.constant(): row_max + }) + for i, row_max in enumerate(maxima) + }) - return state, poly_matrix + return state, init_poly_matrix(p, shape) diff --git a/polymatrix/polymatrix/index.py b/polymatrix/polymatrix/index.py index bab991f..c68e23f 100644 --- a/polymatrix/polymatrix/index.py +++ b/polymatrix/polymatrix/index.py @@ -159,6 +159,11 @@ class MonomialIndex(tuple[VariableIndex]): class PolyDict(UserDict[MonomialIndex, int | float]): """ Polynomial, stored as a dictionary. """ + @property + def degree(self) -> int: + """ Degree of the polynomial. """ + return max(m.degree for m in self.monomials()) + @staticmethod def empty() -> PolyDict: return PolyDict({}) |