summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/polymatrix/impl.py6
-rw-r--r--polymatrix/polymatrix/init.py36
-rw-r--r--polymatrix/polymatrix/mixins.py40
3 files changed, 79 insertions, 3 deletions
diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py
index 816e4e1..37e6429 100644
--- a/polymatrix/polymatrix/impl.py
+++ b/polymatrix/polymatrix/impl.py
@@ -1,7 +1,7 @@
import dataclassabc
from polymatrix.polymatrix.abc import PolyMatrix
-from polymatrix.polymatrix.mixins import PolyMatrixMixin, BroadcastPolyMatrixMixin, SlicePolyMatrixMixin, PolyMatrixAsAffineExpressionMixin
+from polymatrix.polymatrix.mixins import PolyMatrixMixin, BroadcastPolyMatrixMixin, BlockPolyMatrixMixin, SlicePolyMatrixMixin, PolyMatrixAsAffineExpressionMixin
from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex
@@ -16,6 +16,10 @@ class BroadcastPolyMatrixImpl(BroadcastPolyMatrixMixin):
data: PolyDict
shape: tuple[int, int]
+@dataclassabc.dataclassabc(frozen=True)
+class BlockPolyMatrixImpl(BlockPolyMatrixMixin):
+ blocks: dict[tuple[range, range], PolyMatrixMixin]
+ shape: tuple[int, int]
@dataclassabc.dataclassabc(frozen=True)
class SlicePolyMatrixImpl(SlicePolyMatrixMixin):
diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py
index 0778a50..4c89c29 100644
--- a/polymatrix/polymatrix/init.py
+++ b/polymatrix/polymatrix/init.py
@@ -5,12 +5,18 @@ import numpy as np
from typing import TYPE_CHECKING
-from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl, PolyMatrixImpl, SlicePolyMatrixImpl, PolyMatrixAsAffineExpressionImpl
+from polymatrix.polymatrix.impl import (
+ BroadcastPolyMatrixImpl,
+ BlockPolyMatrixImpl,
+ PolyMatrixImpl,
+ SlicePolyMatrixImpl,
+ PolyMatrixAsAffineExpressionImpl)
+
from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MatrixIndex, MonomialIndex, VariableIndex
from polymatrix.polymatrix.mixins import PolyMatrixAsAffineExpressionMixin
if TYPE_CHECKING:
- from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin, PolyMatrixMixin
+ from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin, PolyMatrixMixin, SlicePolyMatrixMixin
# FIXME: use polymatrix.typing
@@ -53,6 +59,32 @@ def init_broadcast_poly_matrix(
)
+def init_block_poly_matrix(blocks: dict[tuple[range, range], PolyMatrixMixin]) -> BlockPolyMatrixMixin:
+ """
+ Blocks are given as dictionary from row and column ranges to polymatrices.
+ """
+ # Check that ranges are correct, and compute shape of whole thing
+ nrows, ncols = 0, 0
+
+ for (row_range, col_range), block in blocks.items():
+ block_nrows, block_ncols = block.shape
+ if (row_range.stop - row_range.start) != block_nrows:
+ raise ValueError(f"Row range {row_range} given for block "
+ f"with shape {block.shape} is incorrect.")
+
+ if nrows < row_range.stop:
+ nrows = row_range.stop
+
+ if (col_range.stop - col_range.start) != block_ncols:
+ raise ValueError(f"Column range {col_range} given for block "
+ f"with shape {block.shape} is incorrect.")
+
+ if ncols < col_range.stop:
+ ncols = col_range.stop
+
+ return BlockPolyMatrixImpl(blocks, shape=(nrows, ncols))
+
+
def init_slice_poly_matrix(
reference: PolyMatrixMixin,
slices: tuple[int | Iterable[int], int | Iterable[int]]
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py
index fa1644e..4cb89b3 100644
--- a/polymatrix/polymatrix/mixins.py
+++ b/polymatrix/polymatrix/mixins.py
@@ -121,6 +121,46 @@ class BroadcastPolyMatrixMixin(PolyMatrixMixin, ABC):
return self.data or None
+class BlockPolyMatrixMixin(PolyMatrixMixin, ABC):
+ """
+ Polymatrix that is made of other blocks, which themselves are polymatrices.
+ In other words concatenation of matrices, a generalization of vstack, block
+ diagonal, etc. If a block is not present it is filled with zeros.
+ """
+ # Something like this
+ #
+ # [[ block11, block12, ..., block1m ]
+ # [ block21, block22, ..., block2m ]
+ # ...
+ # [ blockn1, blockn2, ..., blocknm ]]
+
+ @property
+ @abstractmethod
+ def blocks(self) -> dict[tuple[range, range], PolyMatrixMixin]:
+ """ The blocks. """
+ # tuple of ranges are row and column ranges that are covered by each block
+ # This class does not check that the blocks are compatible with each other!
+ # That must be done at time of construction, see init_block_poly_matrix
+
+ @override
+ def at(self, row: int, col: int) -> PolyDict:
+ if not (0 <= row < self.shape[0]):
+ raise IndexError(f"Row {row} out of range, shape is {self.shape}")
+
+ if not (0 <= col < self.shape[1]):
+ raise IndexError(f"Column {col} out of range, shape is {self.shape}")
+
+ for (row_range, col_range), pm in self.blocks.items():
+ if row in row_range and col in col_range:
+ block_row = row - row_range.start
+ block_col = col - col_range.start
+ return pm.at(block_row, block_col)
+
+ # (row, col) is within bounds, but there is no block for this, so it is
+ # filled with zeros
+ return PolyDict.empty()
+
+
class SlicePolyMatrixMixin(PolyMatrixMixin, ABC):
""" Slice of a poly matrix. """
@property