From b4888da7cfc9555553c87b10ba78bf3512827d43 Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Mon, 20 May 2024 15:49:01 +0200
Subject: Create ConcatenateExprMixin

---
 polymatrix/expression/impl.py                      |  6 +++++
 polymatrix/expression/init.py                      |  2 ++
 .../expression/mixins/concatenateexprmixin.py      | 29 ++++++++++++++++++++++
 3 files changed, 37 insertions(+)
 create mode 100644 polymatrix/expression/mixins/concatenateexprmixin.py

diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index a1a4269..176fe39 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -11,6 +11,7 @@ from polymatrix.expression.mixins.additionexprmixin import AdditionExprMixin
 from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin
 from polymatrix.expression.mixins.cacheexprmixin import CacheExprMixin
 from polymatrix.expression.mixins.combinationsexprmixin import CombinationsExprMixin
+from polymatrix.expression.mixins.concatenateexprmixin import ConcatenateExprMixin
 from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin
 from polymatrix.expression.mixins.derivativeexprmixin import DerivativeExprMixin
 from polymatrix.expression.mixins.diagexprmixin import DiagExprMixin
@@ -96,6 +97,11 @@ class CombinationsExprImpl(CombinationsExprMixin):
         return f"combinations({self.expression}, {self.degrees})"
 
 
+@dataclassabc.dataclassabc(frozen=True)
+class ConcatenateExprImpl(ConcatenateExprMixin):
+    blocks: tuple[tuple[ExpressionBaseMixin, ...], ...]
+
+
 @dataclassabc.dataclassabc(frozen=True)
 class DerivativeExprImpl(DerivativeExprMixin):
     underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index e65be9e..7c3f5fa 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -59,6 +59,8 @@ def init_combinations_expr(
     )
 
 
+def init_concatenate_expr(blocks: tuple[tuple[ExpressionBaseMixin, ...], ...]):
+    return polymatrix.expression.impl.ConcatenateExprImpl(blocks)
 
 
 def init_diag_expr(
diff --git a/polymatrix/expression/mixins/concatenateexprmixin.py b/polymatrix/expression/mixins/concatenateexprmixin.py
new file mode 100644
index 0000000..62f40a3
--- /dev/null
+++ b/polymatrix/expression/mixins/concatenateexprmixin.py
@@ -0,0 +1,29 @@
+from abc import abstractmethod
+from typing_extensions import override
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate import ExpressionState
+from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.init import init_block_poly_matrix
+
+
+class ConcatenateExprMixin(ExpressionBaseMixin):
+    """ Concatenate matrices. """
+
+    @property
+    @abstractmethod
+    def blocks(self) -> tuple[tuple[ExpressionBaseMixin, ...], ...]:
+        """ Matrices to concatenate, stored in row major order. """
+
+    @override
+    def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
+        blocks = {}
+        for row, row_blocks in enumerate(self.blocks):
+            for col, block_expr in enumerate(row_blocks):
+                state, block = block_expr.apply(state)
+                block_nrows, block_ncols = block.shape
+                blocks[range(row, row + block_nrows),
+                       range(col, col + block_ncols)] = block
+
+        return state, init_block_poly_matrix(blocks)
+
-- 
cgit v1.2.1