diff options
-rw-r--r-- | polymatrix/expression/expression.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 2 | ||||
-rw-r--r-- | polymatrix/polymatrix/init.py | 8 |
3 files changed, 8 insertions, 4 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index d697aed..db39780 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -57,7 +57,7 @@ class Expression(ExpressionBaseMixin, ABC): else: return attr - def __getitem__(self, slice: tuple[int | slice, int | slice]) -> Expression: + def __getitem__(self, slice: int | slice | tuple[int | slice, int | slice]) -> Expression: return self.copy( underlying=polymatrix.expression.init.init_slice_expr( underlying=self.underlying, diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 177d44e..9451ba6 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -375,7 +375,7 @@ class ShapeExprImpl(ShapeExprMixin): @dataclassabc.dataclassabc(frozen=True) class SliceExprImpl(SliceExprMixin): underlying: ExpressionBaseMixin - slice: tuple[int | slice | range, int | slice | range] + slice: tuple # See SlicePolyMatrixMixin for details of this tuple @dataclassabc.dataclassabc(frozen=True) diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py index 4c89c29..b8af9f7 100644 --- a/polymatrix/polymatrix/init.py +++ b/polymatrix/polymatrix/init.py @@ -3,7 +3,7 @@ from __future__ import annotations import math import numpy as np -from typing import TYPE_CHECKING +from typing import Iterable, TYPE_CHECKING from polymatrix.polymatrix.impl import ( BroadcastPolyMatrixImpl, @@ -87,9 +87,13 @@ def init_block_poly_matrix(blocks: dict[tuple[range, range], PolyMatrixMixin]) - def init_slice_poly_matrix( reference: PolyMatrixMixin, - slices: tuple[int | Iterable[int], int | Iterable[int]] + slices: int | slice | range | tuple[int | Iterable[int], int | Iterable[int]] ) -> SlicePolyMatrixMixin: + # For example v[0] it is implicitly converted to v[0, :] + if isinstance(slices, int | slice | range): + slices = (slices, slice(None, None, None)) + formatted_slices: list[tuple] = [(), ()] shape = [0, 0] |