1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
|
import abc
import itertools
import dataclassabc
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
class BlockDiagExprMixin(ExpressionBaseMixin):
    """
    Create a block diagonal polymatrix from provided polymatrices.

    Each underlying polymatrix is placed directly below and to the right of
    the previous one; every entry outside the blocks is zero, e.g.

        [[x1]], [[x2], [x3]] -> [[x1, 0], [0, x2], [0, x3]].

    With no underlying expressions the result is an empty (0, 0) polymatrix.
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> tuple[ExpressionBaseMixin, ...]:
        """The expressions whose polymatrices form the diagonal blocks, in order."""
        ...

    # overwrites the abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionState,
    ) -> tuple[ExpressionState, PolyMatrix]:
        """
        Evaluate every underlying expression and combine the results into a
        block diagonal polymatrix.

        :param state: expression state, threaded through each underlying
            ``apply`` call in order.
        :returns: the final state and the block diagonal polymatrix.
        """
        # The state returned by one child's evaluation is the input state of
        # the next, so evaluation order follows `self.underlying`.
        all_underlying = []
        for expr in self.underlying:
            state, polymat = expr.apply(state=state)
            all_underlying.append(polymat)

        @dataclassabc.dataclassabc(frozen=True)
        class BlockDiagPolyMatrix(PolyMatrixMixin):
            # the diagonal blocks, in order
            all_underlying: tuple[PolyMatrixMixin, ...]
            # per block: ((row_start, col_start), (row_end, col_end)), the
            # end-exclusive corners of that block inside the combined matrix
            underlying_row_col_range: tuple[
                tuple[tuple[int, int], tuple[int, int]], ...
            ]
            shape: tuple[int, int]

            def get_poly(
                self, row: int, col: int
            ) -> dict[tuple[int, ...], float] | None:
                """
                Return the polynomial at (row, col), or None for an entry
                that lies outside every block (a structural zero).

                :raises IndexError: if row or col is outside ``shape``.
                """
                n_rows, n_cols = self.shape
                if not 0 <= row < n_rows:
                    raise IndexError(f"row {row} is out of bounds")
                if not 0 <= col < n_cols:
                    raise IndexError(f"col {col} is out of bounds")

                for polymatrix, ((row_start, col_start), (row_end, col_end)) in zip(
                    self.all_underlying, self.underlying_row_col_range
                ):
                    if row_start <= row < row_end:
                        # A row belongs to exactly one block; columns outside
                        # that block's column span are zero.
                        if col_start <= col < col_end:
                            return polymatrix.get_poly(
                                row=row - row_start,
                                col=col - col_start,
                            )
                        return None

                # Not reachable for in-bounds rows when the ranges were built
                # from the block shapes; kept as a safeguard.
                raise IndexError(f"row {row} is out of bounds")

        # Cumulative (row, col) offsets: (0, 0) followed by the running sum of
        # the block shapes. Always non-empty thanks to `initial`.
        offsets = tuple(
            itertools.accumulate(
                (polymat.shape for polymat in all_underlying),
                lambda acc, v: tuple(v1 + v2 for v1, v2 in zip(acc, v)),
                initial=(0, 0),
            )
        )

        # Consecutive offsets are the start/end corners of each block.
        underlying_row_col_range = tuple(itertools.pairwise(offsets))

        # The last offset is the total shape; (0, 0) when there are no blocks.
        shape = offsets[-1]

        polymatrix = BlockDiagPolyMatrix(
            # tuple, not list: the frozen dataclass must stay hashable
            all_underlying=tuple(all_underlying),
            shape=shape,
            underlying_row_col_range=underlying_row_col_range,
        )

        return state, polymatrix
|