summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/impl.py2
-rw-r--r--polymatrix/expression/init.py34
-rw-r--r--polymatrix/expression/mixins/blockdiagexprmixin.py82
3 files changed, 32 insertions, 86 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 8eb1857..47829e0 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -102,7 +102,7 @@ class ARangeExprImpl(ARangeExprMixin):
@dataclassabc.dataclassabc(frozen=True)
class BlockDiagExprImpl(BlockDiagExprMixin):
- underlying: tuple[ExpressionBaseMixin]
+ blocks: tuple[ExpressionBaseMixin, ...]
@dataclassabc.dataclassabc(frozen=True)
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 1db855f..31434f7 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -37,20 +37,12 @@ def init_arange_expr(
return polymatrix.expression.impl.ARangeExprImpl(start, stop, step)
-def init_block_diag_expr(
- underlying: tuple,
-):
- return polymatrix.expression.impl.BlockDiagExprImpl(
- underlying=underlying,
- )
+def init_block_diag_expr(blocks: tuple[ExpressionBaseMixin, ...]):
+ return polymatrix.expression.impl.BlockDiagExprImpl(blocks)
-def init_cache_expr(
- underlying: ExpressionBaseMixin,
-):
- return polymatrix.expression.impl.CacheExprImpl(
- underlying=underlying,
- )
+def init_cache_expr(underlying: ExpressionBaseMixin):
+ return polymatrix.expression.impl.CacheExprImpl(underlying=underlying)
def init_combinations_expr(
@@ -70,18 +62,13 @@ def init_concatenate_expr(blocks: tuple[tuple[ExpressionBaseMixin, ...], ...]):
return polymatrix.expression.impl.ConcatenateExprImpl(blocks)
-def init_diag_expr(
- underlying: ExpressionBaseMixin,
-):
+def init_diag_expr(underlying: ExpressionBaseMixin):
return polymatrix.expression.impl.DiagExprImpl(
underlying=underlying,
)
-def init_divergence_expr(
- underlying: ExpressionBaseMixin,
- variables: tuple,
-):
+def init_divergence_expr(underlying: ExpressionBaseMixin, variables: tuple):
return polymatrix.expression.impl.DivergenceExprImpl(
underlying=underlying,
variables=variables,
@@ -132,17 +119,14 @@ def init_eval_expr(
)
-def init_eye_expr(
- variable: ExpressionBaseMixin,
-):
+# FIXME: make it take a shape not a variable
+def init_eye_expr(variable: ExpressionBaseMixin):
return polymatrix.expression.impl.EyeExprImpl(
variable=variable,
)
-def init_from_symmetric_matrix_expr(
- underlying: ExpressionBaseMixin,
-):
+def init_from_symmetric_matrix_expr(underlying: ExpressionBaseMixin):
return polymatrix.expression.impl.FromSymmetricMatrixExprImpl(
underlying=underlying,
)
diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py
index 45e5e1b..248f5b7 100644
--- a/polymatrix/expression/mixins/blockdiagexprmixin.py
+++ b/polymatrix/expression/mixins/blockdiagexprmixin.py
@@ -1,79 +1,41 @@
-import abc
-import itertools
-import dataclassabc
+from abc import abstractmethod
+from typing_extensions import override
-from polymatrix.polymatrix.mixins import PolyMatrixMixin
-from polymatrix.polymatrix.abc import PolyMatrix
-from polymatrix.polymatrix.index import PolyDict
+from polymatrix.polymatrix.mixins import BlockPolyMatrixMixin
+from polymatrix.polymatrix.init import init_block_poly_matrix
from polymatrix.expressionstate import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
class BlockDiagExprMixin(ExpressionBaseMixin):
"""
- Create a block diagonal polymatrix from provided polymatrices
+ Create a block diagonal matrix from a tuple of matrices
+ ::
- [[x1]], [[x2], [x3]] -> [[x1, 0], [0, x2], [0, x3]].
+ (A, B, C) -> [ A ]
+ [ B ]
+ [ C ]
"""
@property
- @abc.abstractmethod
- def underlying(self) -> tuple[ExpressionBaseMixin, ...]: ...
+ @abstractmethod
+ def blocks(self) -> tuple[ExpressionBaseMixin, ...]: ...
- # overwrites the abstract method of `ExpressionBaseMixin`
+ @override
def apply(
self,
state: ExpressionState,
- ) -> tuple[ExpressionState, PolyMatrix]: # FIXME: incorrect return type
- all_underlying = []
- for expr in self.underlying:
- state, polymat = expr.apply(state=state)
- all_underlying.append(polymat)
+ ) -> tuple[ExpressionState, BlockPolyMatrixMixin]:
- # FIXME: move to polymatrix module
- @dataclassabc.dataclassabc(frozen=True)
- class BlockDiagPolyMatrix(PolyMatrixMixin):
- all_underlying: tuple[PolyMatrixMixin]
- underlying_row_col_range: tuple[tuple[int, int], ...]
- shape: tuple[int, int]
+ d, row, col = {}, 0, 0
+ for block in self.blocks:
+ state, pm = block.apply(state)
+ block_nrows, block_ncols = pm.shape
- def at(self, row: int, col: int) -> PolyDict:
- return self.get_poly(row, col) or PolyDict.empty()
+ d[range(row, row + block_nrows),
+ range(col, col + block_ncols)] = pm
- # FIXME: typing problems
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
- for polymatrix, ((row_start, col_start), (row_end, col_end)) in zip(
- self.all_underlying, self.underlying_row_col_range
- ):
- if row_start <= row < row_end:
- if col_start <= col < col_end:
- return polymatrix.get_poly(
- row=row - row_start,
- col=col - col_start,
- )
+ row += block_nrows
+ col += block_ncols
- else:
- return None
-
- # NP: Do not raise generic expression, specialize error
- raise Exception(f"row {row} is out of bounds")
-
- underlying_row_col_range = tuple(
- itertools.pairwise(
- itertools.accumulate(
- (expr.shape for expr in all_underlying),
- lambda acc, v: tuple(v1 + v2 for v1, v2 in zip(acc, v)),
- initial=(0, 0),
- )
- )
- )
-
- shape = underlying_row_col_range[-1][1]
-
- polymatrix = BlockDiagPolyMatrix(
- all_underlying=all_underlying,
- shape=shape,
- underlying_row_col_range=underlying_row_col_range,
- )
-
- return state, polymatrix
+ return state, init_block_poly_matrix(d)