summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-20 14:10:42 +0200
committerNao Pross <np@0hm.ch>2024-05-20 14:10:42 +0200
commitffa4d7eacc86981dcc1a68e06ffed54f1bd9037a (patch)
tree7ccf7992fb0b08df19cad2ced5da97abab0eb7f3
parentDelete broken substitute methods in Expression (diff)
downloadpolymatrix-ffa4d7eacc86981dcc1a68e06ffed54f1bd9037a.tar.gz
polymatrix-ffa4d7eacc86981dcc1a68e06ffed54f1bd9037a.zip
Create Lower triangular matrix expression
-rw-r--r--polymatrix/expression/impl.py6
-rw-r--r--polymatrix/expression/init.py4
-rw-r--r--polymatrix/expression/mixins/lowertriangularexprmixin.py59
-rw-r--r--polymatrix/polymatrix/mixins.py5
4 files changed, 69 insertions, 5 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 2eeea39..a1a4269 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -31,6 +31,7 @@ from polymatrix.expression.mixins.legendreseriesmixin import LegendreSeriesMixin
from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin
from polymatrix.expression.mixins.linearmatrixinexprmixin import LinearMatrixInExprMixin
from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomialsExprMixin
+from polymatrix.expression.mixins.lowertriangularexprmixin import LowerTriangularExprMixin
from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin
from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin
from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin
@@ -233,6 +234,11 @@ class LinearMonomialsExprImpl(LinearMonomialsExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class LowerTriangularExprImpl(LowerTriangularExprMixin):
+ underlying: ExpressionBaseMixin
+
+
+@dataclassabc.dataclassabc(frozen=True)
class LegendreSeriesImpl(LegendreSeriesMixin):
underlying: ExpressionBaseMixin
degrees: tuple[int, ...] | None
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 30fd3e1..825bc10 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -275,6 +275,10 @@ def init_linear_matrix_in_expr(
)
+def init_lower_triangular_expr(underlying: ExpressionBaseMixin):
+ return polymatrix.expression.impl.LowerTriangularExprImpl(underlying)
+
+
def init_matrix_mult_expr(
left: ExpressionBaseMixin,
right: ExpressionBaseMixin,
diff --git a/polymatrix/expression/mixins/lowertriangularexprmixin.py b/polymatrix/expression/mixins/lowertriangularexprmixin.py
new file mode 100644
index 0000000..74368a4
--- /dev/null
+++ b/polymatrix/expression/mixins/lowertriangularexprmixin.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from math import sqrt, isclose
+from typing_extensions import override
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate import ExpressionState
+from polymatrix.polymatrix.index import PolyMatrixDict
+from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
+
+class LowerTriangularExprMixin(ExpressionBaseMixin):
+ @property
+ @abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ r"""
+ Construct a lower triangular matrix from a vector.
+
+ For a :math:`n\times n` lower diagonal matrix, this is a :math:`n * (n
+ + 1) / 2` dimensional vector containing the entries. The lower diagonal
+ matrix is filled from top to bottom so given the vector :math:`z` the
+ lower triangular matrix is
+
+ .. math::
+ Z = \begin{bmatrix}
+ z_1 \\
+ z_2 & z_3 \\
+ z_4 & z_5 & z_6 \\
+ \vdots & & & \ddots \\
+ \end{bmatrix}.
+
+ """
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
+ state, u = self.underlying.apply(state)
+
+ N, ncols = u.shape
+ if ncols > 1:
+ raise ValueError("Cannot construct lower triangular matrix from object "
+ f"with shape {u.shape}, it must be a column vector.")
+
+ n_float = .5 * (-1 + sqrt(1 + 8 * N))
+ if not isclose(int(n_float), n_float):
+ raise ValueError("To construct a n x n lower triangular matrix, the vector "
+ "must be n * (n + 1) / 2 column dimensional.")
+
+ n = int(n_float)
+ p = PolyMatrixDict.empty()
+
+ for row in range(n):
+ for col in range(row):
+ p[row, col] = u.at(row + col, 1)
+
+ return state, init_poly_matrix(p, shape=(n, n))
+
+
+
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py
index 63eeb18..fa1644e 100644
--- a/polymatrix/polymatrix/mixins.py
+++ b/polymatrix/polymatrix/mixins.py
@@ -13,11 +13,6 @@ from typing_extensions import override
from polymatrix.polymatrix.index import PolyDict, PolyMatrixDict, MatrixIndex, MonomialIndex, VariableIndex
from polymatrix.utils.deprecation import deprecated
-if TYPE_CHECKING:
- from polymatrix.expression.mixins.variablemixin import VariableMixin
- from polymatrix.expressionstate.mixins import ExpressionStateMixin
-
-
class PolyMatrixMixin(ABC):
"""
Matrix with polynomial entries.