diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/impl.py | 6 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/mixins/concatenateexprmixin.py | 29 |
3 files changed, 37 insertions, 0 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index a1a4269..176fe39 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -11,6 +11,7 @@ from polymatrix.expression.mixins.additionexprmixin import AdditionExprMixin from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin from polymatrix.expression.mixins.cacheexprmixin import CacheExprMixin from polymatrix.expression.mixins.combinationsexprmixin import CombinationsExprMixin +from polymatrix.expression.mixins.concatenateexprmixin import ConcatenateExprMixin from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin from polymatrix.expression.mixins.derivativeexprmixin import DerivativeExprMixin from polymatrix.expression.mixins.diagexprmixin import DiagExprMixin @@ -97,6 +98,11 @@ class CombinationsExprImpl(CombinationsExprMixin): @dataclassabc.dataclassabc(frozen=True) +class ConcatenateExprImpl(ConcatenateExprMixin): + blocks: tuple[tuple[ExpressionBaseMixin, ...], ...] + + +@dataclassabc.dataclassabc(frozen=True) class DerivativeExprImpl(DerivativeExprMixin): underlying: ExpressionBaseMixin variables: tuple diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index e65be9e..7c3f5fa 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -59,6 +59,8 @@ def init_combinations_expr( ) +def init_concatenate_expr(blocks: tuple[tuple[ExpressionBaseMixin, ...], ...]): + return polymatrix.expression.impl.ConcatenateExprImpl(blocks) def init_diag_expr( diff --git a/polymatrix/expression/mixins/concatenateexprmixin.py b/polymatrix/expression/mixins/concatenateexprmixin.py new file mode 100644 index 0000000..62f40a3 --- /dev/null +++ b/polymatrix/expression/mixins/concatenateexprmixin.py @@ -0,0 +1,29 @@ +from abc import abstractmethod +from typing_extensions import override + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.init import init_block_poly_matrix + + +class ConcatenateExprMixin(ExpressionBaseMixin): + """ Concatenate matrices. """ + + @property + @abstractmethod + def blocks(self) -> tuple[tuple[ExpressionBaseMixin, ...], ...]: + """ Matrices to concatenate, stored in row major order. """ + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: + blocks = {} + for row, row_blocks in enumerate(self.blocks): + for col, block_expr in enumerate(row_blocks): + state, block = block_expr.apply(state) + block_nrows, block_ncols = block.shape + blocks[range(row, row + block_nrows), + range(col, col + block_ncols)] = block + + return state, init_block_poly_matrix(blocks) + |