diff options
-rw-r--r-- | polymatrix/polymatrix/impl.py | 6 | ||||
-rw-r--r-- | polymatrix/polymatrix/init.py | 36 | ||||
-rw-r--r-- | polymatrix/polymatrix/mixins.py | 40 |
3 files changed, 79 insertions, 3 deletions
diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py index 816e4e1..37e6429 100644 --- a/polymatrix/polymatrix/impl.py +++ b/polymatrix/polymatrix/impl.py @@ -1,7 +1,7 @@ import dataclassabc from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.polymatrix.mixins import PolyMatrixMixin, BroadcastPolyMatrixMixin, SlicePolyMatrixMixin, PolyMatrixAsAffineExpressionMixin +from polymatrix.polymatrix.mixins import PolyMatrixMixin, BroadcastPolyMatrixMixin, BlockPolyMatrixMixin, SlicePolyMatrixMixin, PolyMatrixAsAffineExpressionMixin from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex @@ -16,6 +16,10 @@ class BroadcastPolyMatrixImpl(BroadcastPolyMatrixMixin): data: PolyDict shape: tuple[int, int] +@dataclassabc.dataclassabc(frozen=True) +class BlockPolyMatrixImpl(BlockPolyMatrixMixin): + blocks: dict[tuple[range, range], PolyMatrixMixin] + shape: tuple[int, int] @dataclassabc.dataclassabc(frozen=True) class SlicePolyMatrixImpl(SlicePolyMatrixMixin): diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py index 0778a50..4c89c29 100644 --- a/polymatrix/polymatrix/init.py +++ b/polymatrix/polymatrix/init.py @@ -5,12 +5,18 @@ import numpy as np from typing import TYPE_CHECKING -from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl, PolyMatrixImpl, SlicePolyMatrixImpl, PolyMatrixAsAffineExpressionImpl +from polymatrix.polymatrix.impl import ( + BroadcastPolyMatrixImpl, + BlockPolyMatrixImpl, + PolyMatrixImpl, + SlicePolyMatrixImpl, + PolyMatrixAsAffineExpressionImpl) + from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MatrixIndex, MonomialIndex, VariableIndex from polymatrix.polymatrix.mixins import PolyMatrixAsAffineExpressionMixin if TYPE_CHECKING: - from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin, PolyMatrixMixin + from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin, PolyMatrixMixin, SlicePolyMatrixMixin # FIXME: use polymatrix.typing @@ -53,6 +59,32 @@ def init_broadcast_poly_matrix( ) +def init_block_poly_matrix(blocks: dict[tuple[range, range], PolyMatrixMixin]) -> BlockPolyMatrixMixin: + """ + Blocks are given as dictionary from row and column ranges to polymatrices. + """ + # Check that ranges are correct, and compute shape of whole thing + nrows, ncols = 0, 0 + + for (row_range, col_range), block in blocks.items(): + block_nrows, block_ncols = block.shape + if (row_range.stop - row_range.start) != block_nrows: + raise ValueError(f"Row range {row_range} given for block " + f"with shape {block.shape} is incorrect.") + + if nrows < row_range.stop: + nrows = row_range.stop + + if (col_range.stop - col_range.start) != block_ncols: + raise ValueError(f"Column range {col_range} given for block " + f"with shape {block.shape} is incorrect.") + + if ncols < col_range.stop: + ncols = col_range.stop + + return BlockPolyMatrixImpl(blocks, shape=(nrows, ncols)) + + def init_slice_poly_matrix( reference: PolyMatrixMixin, slices: tuple[int | Iterable[int], int | Iterable[int]] diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py index fa1644e..4cb89b3 100644 --- a/polymatrix/polymatrix/mixins.py +++ b/polymatrix/polymatrix/mixins.py @@ -121,6 +121,46 @@ class BroadcastPolyMatrixMixin(PolyMatrixMixin, ABC): return self.data or None +class BlockPolyMatrixMixin(PolyMatrixMixin, ABC): + """ + Polymatrix that is made of other blocks, which themselves are polymatrices. + In other words concatenation of matrices, a generalization of vstack, block + diagonal, etc. If a block is not present it is filled with zeros. + """ + # Something like this + # + # [[ block11, block12, ..., block1m ] + # [ block21, block22, ..., block2m ] + # ... + # [ blockn1, blockn2, ..., blocknm ]] + + @property + @abstractmethod + def blocks(self) -> dict[tuple[range, range], PolyMatrixMixin]: + """ The blocks. """ + # tuple of ranges are row and column ranges that are covered by each block + # This class does not check that the blocks are compatible with each other! + # That must be done at time of construction, see init_block_poly_matrix + + @override + def at(self, row: int, col: int) -> PolyDict: + if not (0 <= row < self.shape[0]): + raise IndexError(f"Row {row} out of range, shape is {self.shape}") + + if not (0 <= col < self.shape[1]): + raise IndexError(f"Column {col} out of range, shape is {self.shape}") + + for (row_range, col_range), pm in self.blocks.items(): + if row in row_range and col in col_range: + block_row = row - row_range.start + block_col = col - col_range.start + return pm.at(block_row, block_col) + + # (row, col) is within bounds, but there is no block for this, so it is + # filled with zeros + return PolyDict.empty() + + class SlicePolyMatrixMixin(PolyMatrixMixin, ABC): """ Slice of a poly matrix. """ @property |