summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-25 11:45:41 +0200
committerNao Pross <np@0hm.ch>2024-05-25 11:45:41 +0200
commit84aac6bcdbf826b0f5107d3f0cc2f57106d390a0 (patch)
tree74bce129b89312d6d7f0d194566c8647bdda67ac
parentFix bug in LowerTriangularExpr, expose lower_triangular and block_diag (diff)
downloadpolymatrix-84aac6bcdbf826b0f5107d3f0cc2f57106d390a0.tar.gz
polymatrix-84aac6bcdbf826b0f5107d3f0cc2f57106d390a0.zip
Fix ConcatenateExpr bug and expose concatenate
-rw-r--r--polymatrix/__init__.py2
-rw-r--r--polymatrix/expression/__init__.py5
-rw-r--r--polymatrix/expression/mixins/concatenateexprmixin.py11
3 files changed, 16 insertions, 2 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 33ad143..cd1a4e1 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -15,6 +15,7 @@ from polymatrix.expression import (
v_stack as internal_v_stack,
h_stack as internal_h_stack,
product as internal_product,
+ concatenate as internal_concatenate,
block_diag as internal_block_diag,
lower_triangular as internal_lower_triangular,
)
@@ -36,6 +37,7 @@ make_state = init_expression_state
v_stack = internal_v_stack
h_stack = internal_h_stack
product = internal_product
+concatenate = internal_concatenate
block_diag = internal_block_diag
lower_triangular = internal_lower_triangular
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index 2864fc5..90cf425 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -58,6 +58,11 @@ def product(
)
+def concatenate(arrays: Iterable[Iterable[Expression]]):
+ return init_expression(underlying=polymatrix.expression.impl.ConcatenateExprImpl(
+ tuple(tuple(expr.underlying for expr in row) for row in arrays)))
+
+
def lower_triangular(vector: Expression):
return init_expression(
underlying=polymatrix.expression.impl.LowerTriangularExprImpl(underlying=vector))
diff --git a/polymatrix/expression/mixins/concatenateexprmixin.py b/polymatrix/expression/mixins/concatenateexprmixin.py
index 62f40a3..500a700 100644
--- a/polymatrix/expression/mixins/concatenateexprmixin.py
+++ b/polymatrix/expression/mixins/concatenateexprmixin.py
@@ -18,12 +18,19 @@ class ConcatenateExprMixin(ExpressionBaseMixin):
@override
def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
blocks = {}
- for row, row_blocks in enumerate(self.blocks):
- for col, block_expr in enumerate(row_blocks):
+ row = 0
+ for row_blocks in self.blocks:
+ col = 0
+ for block_expr in row_blocks:
state, block = block_expr.apply(state)
block_nrows, block_ncols = block.shape
+
+ # init_block_polymatrix will check if the ranges are correct
blocks[range(row, row + block_nrows),
range(col, col + block_ncols)] = block
+ col += block_ncols
+ row += block_nrows
+
return state, init_block_poly_matrix(blocks)