summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-25 10:54:56 +0200
committerNao Pross <np@0hm.ch>2024-05-25 11:31:42 +0200
commitd494f78e904e751772bad16d4d58eae59f76c3b8 (patch)
tree469650714ac745cb441b864ab36f77c3224a80ca
parentImprove slicing, allow M[i] as shorthand for M[i, :] (diff)
downloadpolymatrix-d494f78e904e751772bad16d4d58eae59f76c3b8.tar.gz
polymatrix-d494f78e904e751772bad16d4d58eae59f76c3b8.zip
Fix bug in LowerTriangularExpr, expose lower_triangular and block_diag
-rw-r--r--polymatrix/__init__.py9
-rw-r--r--polymatrix/expression/__init__.py5
-rw-r--r--polymatrix/expression/mixins/lowertriangularexprmixin.py7
3 files changed, 16 insertions, 5 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 62fc4a6..33ad143 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -15,6 +15,8 @@ from polymatrix.expression import (
v_stack as internal_v_stack,
h_stack as internal_h_stack,
product as internal_product,
+ block_diag as internal_block_diag,
+ lower_triangular as internal_lower_triangular,
)
from polymatrix.expression.to import (
@@ -34,12 +36,15 @@ make_state = init_expression_state
v_stack = internal_v_stack
h_stack = internal_h_stack
product = internal_product
+block_diag = internal_block_diag
+lower_triangular = internal_lower_triangular
to_constant_repr = internal_to_constant
-to_constant = internal_to_constant
to_sympy_repr = internal_to_sympy
-to_sympy = internal_to_sympy
to_matrix_repr = from_polymatrix
+
+to_constant = internal_to_constant
+to_sympy = internal_to_sympy
to_dense = from_polymatrix
to_affine = to_affine_expression
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index f8b5b74..2864fc5 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -56,3 +56,8 @@ def product(
stack=get_stack_lines(),
)
)
+
+
+def lower_triangular(vector: Expression):
+ return init_expression(
+ underlying=polymatrix.expression.impl.LowerTriangularExprImpl(underlying=vector))
diff --git a/polymatrix/expression/mixins/lowertriangularexprmixin.py b/polymatrix/expression/mixins/lowertriangularexprmixin.py
index 74368a4..5bdfb3e 100644
--- a/polymatrix/expression/mixins/lowertriangularexprmixin.py
+++ b/polymatrix/expression/mixins/lowertriangularexprmixin.py
@@ -44,14 +44,15 @@ class LowerTriangularExprMixin(ExpressionBaseMixin):
n_float = .5 * (-1 + sqrt(1 + 8 * N))
if not isclose(int(n_float), n_float):
raise ValueError("To construct a n x n lower triangular matrix, the vector "
- "must be n * (n + 1) / 2 column dimensional.")
+ "must be an n * (n + 1) / 2 dimensional column.")
n = int(n_float)
p = PolyMatrixDict.empty()
for row in range(n):
- for col in range(row):
- p[row, col] = u.at(row + col, 1)
+ i = row * (row + 1) // 2
+ for col in range(row + 1):
+ p[row, col] = u.at(i + col, 0)
return state, init_poly_matrix(p, shape=(n, n))