summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py18
-rw-r--r--polymatrix/polymatrix/init.py5
2 files changed, 4 insertions, 19 deletions
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index ce5e97c..0d89361 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -3,7 +3,7 @@ import itertools
import typing
import dataclassabc
-from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.init import init_poly_matrix, init_broadcast_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.polymatrix.abc import PolyMatrix
@@ -31,21 +31,7 @@ class ElemMultExprMixin(ExpressionBaseMixin):
if right.shape == (1, 1):
right_poly = right.get_poly(0, 0)
-
- @dataclassabc.dataclassabc(frozen=True)
- class BroadCastedPolyMatrix(PolyMatrixMixin):
- underlying: tuple[tuple[int], float]
- shape: tuple[int, int]
-
- def get_poly(
- self, row: int, col: int
- ) -> typing.Optional[dict[tuple[int, ...], float]]:
- return self.underlying
-
- right = BroadCastedPolyMatrix(
- underlying=right_poly,
- shape=left.shape,
- )
+ right = init_broadcast_poly_matrix(data=right_poly, shape=right.shape)
poly_matrix_data = {}
diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py
index 704ebe5..6d9adaf 100644
--- a/polymatrix/polymatrix/init.py
+++ b/polymatrix/polymatrix/init.py
@@ -1,3 +1,4 @@
+from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin
from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl, PolyMatrixImpl
from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict
@@ -13,12 +14,10 @@ def init_poly_matrix(
)
-# NP: cosider renaming to scalar cf. comment on broadcast_poly_matrix and BroadcastPolymatrix
-# FIXME: use polymatrix.typing
def init_broadcast_poly_matrix(
data: PolyDict,
shape: tuple[int, int],
-):
+) -> BroadcastPolyMatrixMixin:
return BroadcastPolyMatrixImpl(
data=data,
shape=shape,