From 84aac6bcdbf826b0f5107d3f0cc2f57106d390a0 Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Sat, 25 May 2024 11:45:41 +0200
Subject: Fix ConcatenateExpr bug and expose concatenate

---
 polymatrix/__init__.py                               |  2 ++
 polymatrix/expression/__init__.py                    |  5 +++++
 polymatrix/expression/mixins/concatenateexprmixin.py | 11 +++++++++--
 3 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 33ad143..cd1a4e1 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -15,6 +15,7 @@ from polymatrix.expression import (
     v_stack as internal_v_stack,
     h_stack as internal_h_stack,
     product as internal_product,
+    concatenate as internal_concatenate,
     block_diag as internal_block_diag,
     lower_triangular as internal_lower_triangular,
 )
@@ -36,6 +37,7 @@ make_state = init_expression_state
 v_stack = internal_v_stack
 h_stack = internal_h_stack
 product = internal_product
+concatenate = internal_concatenate
 block_diag = internal_block_diag
 lower_triangular = internal_lower_triangular
 
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index 2864fc5..90cf425 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -58,6 +58,11 @@ def product(
     )
 
 
+def concatenate(arrays: Iterable[Iterable[Expression]]):
+    return init_expression(underlying=polymatrix.expression.impl.ConcatenateExprImpl(
+        tuple(tuple(expr.underlying for expr in row) for row in arrays)))
+
+
 def lower_triangular(vector: Expression):
     return init_expression(
             underlying=polymatrix.expression.impl.LowerTriangularExprImpl(underlying=vector))
diff --git a/polymatrix/expression/mixins/concatenateexprmixin.py b/polymatrix/expression/mixins/concatenateexprmixin.py
index 62f40a3..500a700 100644
--- a/polymatrix/expression/mixins/concatenateexprmixin.py
+++ b/polymatrix/expression/mixins/concatenateexprmixin.py
@@ -18,12 +18,19 @@ class ConcatenateExprMixin(ExpressionBaseMixin):
     @override
     def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
         blocks = {}
-        for row, row_blocks in enumerate(self.blocks):
-            for col, block_expr in enumerate(row_blocks):
+        row = 0
+        for row_blocks in self.blocks:
+            col = 0
+            for block_expr in row_blocks:
                 state, block = block_expr.apply(state)
                 block_nrows, block_ncols = block.shape
+
+                # init_block_polymatrix will check if the ranges are correct
                 blocks[range(row, row + block_nrows),
                        range(col, col + block_ncols)] = block
 
+                col += block_ncols
+            row += block_nrows
+
         return state, init_block_poly_matrix(blocks)
 
-- 
cgit v1.2.1