diff options
author | Nao Pross <np@0hm.ch> | 2024-05-25 11:45:41 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-25 11:45:41 +0200 |
commit | 84aac6bcdbf826b0f5107d3f0cc2f57106d390a0 (patch) | |
tree | 74bce129b89312d6d7f0d194566c8647bdda67ac | |
parent | Fix bug in LowerTriangularExpr, expose lower_triangular and block_diag (diff) | |
download | polymatrix-84aac6bcdbf826b0f5107d3f0cc2f57106d390a0.tar.gz polymatrix-84aac6bcdbf826b0f5107d3f0cc2f57106d390a0.zip |
Fix ConcatenateExpr bug and expose concatenate
-rw-r--r-- | polymatrix/__init__.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/__init__.py | 5 | ||||
-rw-r--r-- | polymatrix/expression/mixins/concatenateexprmixin.py | 11 |
3 files changed, 16 insertions, 2 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 33ad143..cd1a4e1 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -15,6 +15,7 @@ from polymatrix.expression import ( v_stack as internal_v_stack, h_stack as internal_h_stack, product as internal_product, + concatenate as internal_concatenate, block_diag as internal_block_diag, lower_triangular as internal_lower_triangular, ) @@ -36,6 +37,7 @@ make_state = init_expression_state v_stack = internal_v_stack h_stack = internal_h_stack product = internal_product +concatenate = internal_concatenate block_diag = internal_block_diag lower_triangular = internal_lower_triangular diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index 2864fc5..90cf425 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -58,6 +58,11 @@ def product( ) +def concatenate(arrays: Iterable[Iterable[Expression]]): + return init_expression(underlying=polymatrix.expression.impl.ConcatenateExprImpl( + tuple(tuple(expr.underlying for expr in row) for row in arrays))) + + def lower_triangular(vector: Expression): return init_expression( underlying=polymatrix.expression.impl.LowerTriangularExprImpl(underlying=vector)) diff --git a/polymatrix/expression/mixins/concatenateexprmixin.py b/polymatrix/expression/mixins/concatenateexprmixin.py index 62f40a3..500a700 100644 --- a/polymatrix/expression/mixins/concatenateexprmixin.py +++ b/polymatrix/expression/mixins/concatenateexprmixin.py @@ -18,12 +18,19 @@ class ConcatenateExprMixin(ExpressionBaseMixin): @override def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: blocks = {} - for row, row_blocks in enumerate(self.blocks): - for col, block_expr in enumerate(row_blocks): + row = 0 + for row_blocks in self.blocks: + col = 0 + for block_expr in row_blocks: state, block = block_expr.apply(state) block_nrows, block_ncols = block.shape + + # init_block_polymatrix will check if the ranges are correct blocks[range(row, row + block_nrows), range(col, col + block_ncols)] = block + col += block_ncols + row += block_nrows + return state, init_block_poly_matrix(blocks) |