diff options
-rw-r--r-- | polymatrix/expression/impl.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 34 | ||||
-rw-r--r-- | polymatrix/expression/mixins/blockdiagexprmixin.py | 82 |
3 files changed, 32 insertions, 86 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 8eb1857..47829e0 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -102,7 +102,7 @@ class ARangeExprImpl(ARangeExprMixin): @dataclassabc.dataclassabc(frozen=True) class BlockDiagExprImpl(BlockDiagExprMixin): - underlying: tuple[ExpressionBaseMixin] + blocks: tuple[ExpressionBaseMixin, ...] @dataclassabc.dataclassabc(frozen=True) diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 1db855f..31434f7 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -37,20 +37,12 @@ def init_arange_expr( return polymatrix.expression.impl.ARangeExprImpl(start, stop, step) -def init_block_diag_expr( - underlying: tuple, -): - return polymatrix.expression.impl.BlockDiagExprImpl( - underlying=underlying, - ) +def init_block_diag_expr(blocks: tuple[ExpressionBaseMixin, ...]): + return polymatrix.expression.impl.BlockDiagExprImpl(blocks) -def init_cache_expr( - underlying: ExpressionBaseMixin, -): - return polymatrix.expression.impl.CacheExprImpl( - underlying=underlying, - ) +def init_cache_expr(underlying: ExpressionBaseMixin): + return polymatrix.expression.impl.CacheExprImpl(underlying=underlying) def init_combinations_expr( @@ -70,18 +62,13 @@ def init_concatenate_expr(blocks: tuple[tuple[ExpressionBaseMixin, ...], ...]): return polymatrix.expression.impl.ConcatenateExprImpl(blocks) -def init_diag_expr( - underlying: ExpressionBaseMixin, -): +def init_diag_expr(underlying: ExpressionBaseMixin): return polymatrix.expression.impl.DiagExprImpl( underlying=underlying, ) -def init_divergence_expr( - underlying: ExpressionBaseMixin, - variables: tuple, -): +def init_divergence_expr(underlying: ExpressionBaseMixin, variables: tuple): return polymatrix.expression.impl.DivergenceExprImpl( underlying=underlying, variables=variables, @@ -132,17 +119,14 @@ def init_eval_expr( ) -def init_eye_expr( - variable: ExpressionBaseMixin, -): +# FIXME: make it take a shape not a variable +def init_eye_expr(variable: ExpressionBaseMixin): return polymatrix.expression.impl.EyeExprImpl( variable=variable, ) -def init_from_symmetric_matrix_expr( - underlying: ExpressionBaseMixin, -): +def init_from_symmetric_matrix_expr(underlying: ExpressionBaseMixin): return polymatrix.expression.impl.FromSymmetricMatrixExprImpl( underlying=underlying, ) diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py index 45e5e1b..248f5b7 100644 --- a/polymatrix/expression/mixins/blockdiagexprmixin.py +++ b/polymatrix/expression/mixins/blockdiagexprmixin.py @@ -1,79 +1,41 @@ -import abc -import itertools -import dataclassabc +from abc import abstractmethod +from typing_extensions import override -from polymatrix.polymatrix.mixins import PolyMatrixMixin -from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.polymatrix.index import PolyDict +from polymatrix.polymatrix.mixins import BlockPolyMatrixMixin +from polymatrix.polymatrix.init import init_block_poly_matrix from polymatrix.expressionstate import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin class BlockDiagExprMixin(ExpressionBaseMixin): """ - Create a block diagonal polymatrix from provided polymatrices + Create a block diagonal matrix from a tuple of matrices + :: - [[x1]], [[x2], [x3]] -> [[x1, 0], [0, x2], [0, x3]]. + (A, B, C) -> [ A ] + [ B ] + [ C ] """ @property - @abc.abstractmethod - def underlying(self) -> tuple[ExpressionBaseMixin, ...]: ... + @abstractmethod + def blocks(self) -> tuple[ExpressionBaseMixin, ...]: ... - # overwrites the abstract method of `ExpressionBaseMixin` + @override def apply( self, state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: # FIXME: incorrect return type - all_underlying = [] - for expr in self.underlying: - state, polymat = expr.apply(state=state) - all_underlying.append(polymat) + ) -> tuple[ExpressionState, BlockPolyMatrixMixin]: - # FIXME: move to polymatrix module - @dataclassabc.dataclassabc(frozen=True) - class BlockDiagPolyMatrix(PolyMatrixMixin): - all_underlying: tuple[PolyMatrixMixin] - underlying_row_col_range: tuple[tuple[int, int], ...] - shape: tuple[int, int] + d, row, col = {}, 0, 0 + for block in self.blocks: + state, pm = block.apply(state) + block_nrows, block_ncols = pm.shape - def at(self, row: int, col: int) -> PolyDict: - return self.get_poly(row, col) or PolyDict.empty() + d[range(row, row + block_nrows), + range(col, col + block_ncols)] = pm - # FIXME: typing problems - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: - for polymatrix, ((row_start, col_start), (row_end, col_end)) in zip( - self.all_underlying, self.underlying_row_col_range - ): - if row_start <= row < row_end: - if col_start <= col < col_end: - return polymatrix.get_poly( - row=row - row_start, - col=col - col_start, - ) + row += block_nrows + col += block_ncols - else: - return None - - # NP: Do not raise generic expression, specialize error - raise Exception(f"row {row} is out of bounds") - - underlying_row_col_range = tuple( - itertools.pairwise( - itertools.accumulate( - (expr.shape for expr in all_underlying), - lambda acc, v: tuple(v1 + v2 for v1, v2 in zip(acc, v)), - initial=(0, 0), - ) - ) - ) - - shape = underlying_row_col_range[-1][1] - - polymatrix = BlockDiagPolyMatrix( - all_underlying=all_underlying, - shape=shape, - underlying_row_col_range=underlying_row_col_range, - ) - - return state, polymatrix + return state, init_block_poly_matrix(d) |