summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/impl.py6
-rw-r--r--polymatrix/expression/init.py2
-rw-r--r--polymatrix/expression/mixins/concatenateexprmixin.py29
3 files changed, 37 insertions, 0 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index a1a4269..176fe39 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -11,6 +11,7 @@ from polymatrix.expression.mixins.additionexprmixin import AdditionExprMixin
from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin
from polymatrix.expression.mixins.cacheexprmixin import CacheExprMixin
from polymatrix.expression.mixins.combinationsexprmixin import CombinationsExprMixin
+from polymatrix.expression.mixins.concatenateexprmixin import ConcatenateExprMixin
from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin
from polymatrix.expression.mixins.derivativeexprmixin import DerivativeExprMixin
from polymatrix.expression.mixins.diagexprmixin import DiagExprMixin
@@ -97,6 +98,11 @@ class CombinationsExprImpl(CombinationsExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class ConcatenateExprImpl(ConcatenateExprMixin):
+ blocks: tuple[tuple[ExpressionBaseMixin, ...], ...]
+
+
+@dataclassabc.dataclassabc(frozen=True)
class DerivativeExprImpl(DerivativeExprMixin):
underlying: ExpressionBaseMixin
variables: tuple
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index e65be9e..7c3f5fa 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -59,6 +59,8 @@ def init_combinations_expr(
)
+def init_concatenate_expr(blocks: tuple[tuple[ExpressionBaseMixin, ...], ...]):
+ return polymatrix.expression.impl.ConcatenateExprImpl(blocks)
def init_diag_expr(
diff --git a/polymatrix/expression/mixins/concatenateexprmixin.py b/polymatrix/expression/mixins/concatenateexprmixin.py
new file mode 100644
index 0000000..62f40a3
--- /dev/null
+++ b/polymatrix/expression/mixins/concatenateexprmixin.py
@@ -0,0 +1,29 @@
+from abc import abstractmethod
+from typing_extensions import override
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate import ExpressionState
+from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.init import init_block_poly_matrix
+
+
+class ConcatenateExprMixin(ExpressionBaseMixin):
+ """ Concatenate matrices. """
+
+ @property
+ @abstractmethod
+ def blocks(self) -> tuple[tuple[ExpressionBaseMixin, ...], ...]:
+ """ Matrices to concatenate, stored in row major order. """
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
+ blocks = {}
+ for row, row_blocks in enumerate(self.blocks):
+ for col, block_expr in enumerate(row_blocks):
+ state, block = block_expr.apply(state)
+ block_nrows, block_ncols = block.shape
+ blocks[range(row, row + block_nrows),
+ range(col, col + block_ncols)] = block
+
+ return state, init_block_poly_matrix(blocks)
+