diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/mixins/elemmultexprmixin.py | 18 | ||||
-rw-r--r-- | polymatrix/polymatrix/init.py | 5 |
2 files changed, 4 insertions, 19 deletions
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index ce5e97c..0d89361 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -3,7 +3,7 @@ import itertools import typing import dataclassabc -from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.init import init_poly_matrix, init_broadcast_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix @@ -31,21 +31,7 @@ class ElemMultExprMixin(ExpressionBaseMixin): if right.shape == (1, 1): right_poly = right.get_poly(0, 0) - - @dataclassabc.dataclassabc(frozen=True) - class BroadCastedPolyMatrix(PolyMatrixMixin): - underlying: tuple[tuple[int], float] - shape: tuple[int, int] - - def get_poly( - self, row: int, col: int - ) -> typing.Optional[dict[tuple[int, ...], float]]: - return self.underlying - - right = BroadCastedPolyMatrix( - underlying=right_poly, - shape=left.shape, - ) + right = init_broadcast_poly_matrix(data=right_poly, shape=right.shape) poly_matrix_data = {} diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py index 704ebe5..6d9adaf 100644 --- a/polymatrix/polymatrix/init.py +++ b/polymatrix/polymatrix/init.py @@ -1,3 +1,4 @@ +from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl, PolyMatrixImpl from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict @@ -13,12 +14,10 @@ def init_poly_matrix( ) -# NP: cosider renaming to scalar cf. comment on broadcast_poly_matrix and BroadcastPolymatrix -# FIXME: use polymatrix.typing def init_broadcast_poly_matrix( data: PolyDict, shape: tuple[int, int], -): +) -> BroadcastPolyMatrixMixin: return BroadcastPolyMatrixImpl( data=data, shape=shape, |