From 84aac6bcdbf826b0f5107d3f0cc2f57106d390a0 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sat, 25 May 2024 11:45:41 +0200 Subject: Fix ConcatenateExpr bug and expose concatenate --- polymatrix/__init__.py | 2 ++ polymatrix/expression/__init__.py | 5 +++++ polymatrix/expression/mixins/concatenateexprmixin.py | 11 +++++++++-- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 33ad143..cd1a4e1 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -15,6 +15,7 @@ from polymatrix.expression import ( v_stack as internal_v_stack, h_stack as internal_h_stack, product as internal_product, + concatenate as internal_concatenate, block_diag as internal_block_diag, lower_triangular as internal_lower_triangular, ) @@ -36,6 +37,7 @@ make_state = init_expression_state v_stack = internal_v_stack h_stack = internal_h_stack product = internal_product +concatenate = internal_concatenate block_diag = internal_block_diag lower_triangular = internal_lower_triangular diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index 2864fc5..90cf425 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -58,6 +58,11 @@ def product( ) +def concatenate(arrays: Iterable[Iterable[Expression]]): + return init_expression(underlying=polymatrix.expression.impl.ConcatenateExprImpl( + tuple(tuple(expr.underlying for expr in row) for row in arrays))) + + def lower_triangular(vector: Expression): return init_expression( underlying=polymatrix.expression.impl.LowerTriangularExprImpl(underlying=vector)) diff --git a/polymatrix/expression/mixins/concatenateexprmixin.py b/polymatrix/expression/mixins/concatenateexprmixin.py index 62f40a3..500a700 100644 --- a/polymatrix/expression/mixins/concatenateexprmixin.py +++ b/polymatrix/expression/mixins/concatenateexprmixin.py @@ -18,12 +18,19 @@ class ConcatenateExprMixin(ExpressionBaseMixin): @override def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: blocks = {} - for row, row_blocks in enumerate(self.blocks): - for col, block_expr in enumerate(row_blocks): + row = 0 + for row_blocks in self.blocks: + col = 0 + for block_expr in row_blocks: state, block = block_expr.apply(state) block_nrows, block_ncols = block.shape + + # init_block_polymatrix will check if the ranges are correct blocks[range(row, row + block_nrows), range(col, col + block_ncols)] = block + col += block_ncols + row += block_nrows + return state, init_block_poly_matrix(blocks) -- cgit v1.2.1