diff options
-rw-r--r-- | polymatrix/expression/impl.py | 6 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 4 | ||||
-rw-r--r-- | polymatrix/expression/mixins/lowertriangularexprmixin.py | 59 | ||||
-rw-r--r-- | polymatrix/polymatrix/mixins.py | 5 |
4 files changed, 69 insertions, 5 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 2eeea39..a1a4269 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -31,6 +31,7 @@ from polymatrix.expression.mixins.legendreseriesmixin import LegendreSeriesMixin from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin from polymatrix.expression.mixins.linearmatrixinexprmixin import LinearMatrixInExprMixin from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomialsExprMixin +from polymatrix.expression.mixins.lowertriangularexprmixin import LowerTriangularExprMixin from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin @@ -233,6 +234,11 @@ class LinearMonomialsExprImpl(LinearMonomialsExprMixin): @dataclassabc.dataclassabc(frozen=True) +class LowerTriangularExprImpl(LowerTriangularExprMixin): + underlying: ExpressionBaseMixin + + +@dataclassabc.dataclassabc(frozen=True) class LegendreSeriesImpl(LegendreSeriesMixin): underlying: ExpressionBaseMixin degrees: tuple[int, ...] | None diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 30fd3e1..825bc10 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -275,6 +275,10 @@ def init_linear_matrix_in_expr( ) +def init_lower_triangular_expr(underlying: ExpressionBaseMixin): + return polymatrix.expression.impl.LowerTriangularExprImpl(underlying) + + def init_matrix_mult_expr( left: ExpressionBaseMixin, right: ExpressionBaseMixin, diff --git a/polymatrix/expression/mixins/lowertriangularexprmixin.py b/polymatrix/expression/mixins/lowertriangularexprmixin.py new file mode 100644 index 0000000..74368a4 --- /dev/null +++ b/polymatrix/expression/mixins/lowertriangularexprmixin.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from abc import abstractmethod +from math import sqrt, isclose +from typing_extensions import override + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.index import PolyMatrixDict +from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.mixins import PolyMatrixMixin + +class LowerTriangularExprMixin(ExpressionBaseMixin): + @property + @abstractmethod + def underlying(self) -> ExpressionBaseMixin: + r""" + Construct a lower triangular matrix from a vector. + + For a :math:`n\times n` lower diagonal matrix, this is a :math:`n * (n + + 1) / 2` dimensional vector containing the entries. The lower diagonal + matrix is filled from top to bottom so given the vector :math:`z` the + lower triangular matrix is + + .. math:: + Z = \begin{bmatrix} + z_1 \\ + z_2 & z_3 \\ + z_4 & z_5 & z_6 \\ + \vdots & & & \ddots \\ + \end{bmatrix}. + + """ + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: + state, u = self.underlying.apply(state) + + N, ncols = u.shape + if ncols > 1: + raise ValueError("Cannot construct lower triangular matrix from object " + f"with shape {u.shape}, it must be a column vector.") + + n_float = .5 * (-1 + sqrt(1 + 8 * N)) + if not isclose(int(n_float), n_float): + raise ValueError("To construct a n x n lower triangular matrix, the vector " + "must be n * (n + 1) / 2 column dimensional.") + + n = int(n_float) + p = PolyMatrixDict.empty() + + for row in range(n): + for col in range(row): + p[row, col] = u.at(row + col, 1) + + return state, init_poly_matrix(p, shape=(n, n)) + + + diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py index 63eeb18..fa1644e 100644 --- a/polymatrix/polymatrix/mixins.py +++ b/polymatrix/polymatrix/mixins.py @@ -13,11 +13,6 @@ from typing_extensions import override from polymatrix.polymatrix.index import PolyDict, PolyMatrixDict, MatrixIndex, MonomialIndex, VariableIndex from polymatrix.utils.deprecation import deprecated -if TYPE_CHECKING: - from polymatrix.expression.mixins.variablemixin import VariableMixin - from polymatrix.expressionstate.mixins import ExpressionStateMixin - - class PolyMatrixMixin(ABC): """ Matrix with polynomial entries. |