summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-25 10:37:51 +0200
committerNao Pross <np@0hm.ch>2024-05-25 10:37:51 +0200
commit8371fac2d550d96d6f9d4f479cb1f203eaff452d (patch)
treec2c0802d003d5d8cd7ba2059dbeb8567233d7554
parentMake Expression.shape a property (diff)
downloadpolymatrix-8371fac2d550d96d6f9d4f479cb1f203eaff452d.tar.gz
polymatrix-8371fac2d550d96d6f9d4f479cb1f203eaff452d.zip
Improve slicing, allow M[i] as shorthand for M[i, :]
-rw-r--r--polymatrix/expression/expression.py2
-rw-r--r--polymatrix/expression/impl.py2
-rw-r--r--polymatrix/polymatrix/init.py8
3 files changed, 8 insertions, 4 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index d697aed..db39780 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -57,7 +57,7 @@ class Expression(ExpressionBaseMixin, ABC):
else:
return attr
- def __getitem__(self, slice: tuple[int | slice, int | slice]) -> Expression:
+ def __getitem__(self, slice: int | slice | tuple[int | slice, int | slice]) -> Expression:
return self.copy(
underlying=polymatrix.expression.init.init_slice_expr(
underlying=self.underlying,
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 177d44e..9451ba6 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -375,7 +375,7 @@ class ShapeExprImpl(ShapeExprMixin):
@dataclassabc.dataclassabc(frozen=True)
class SliceExprImpl(SliceExprMixin):
underlying: ExpressionBaseMixin
- slice: tuple[int | slice | range, int | slice | range]
+ slice: tuple # See SlicePolyMatrixMixin for details of this tuple
@dataclassabc.dataclassabc(frozen=True)
diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py
index 4c89c29..b8af9f7 100644
--- a/polymatrix/polymatrix/init.py
+++ b/polymatrix/polymatrix/init.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import math
import numpy as np
-from typing import TYPE_CHECKING
+from typing import Iterable, TYPE_CHECKING
from polymatrix.polymatrix.impl import (
BroadcastPolyMatrixImpl,
@@ -87,9 +87,13 @@ def init_block_poly_matrix(blocks: dict[tuple[range, range], PolyMatrixMixin]) -
def init_slice_poly_matrix(
reference: PolyMatrixMixin,
- slices: tuple[int | Iterable[int], int | Iterable[int]]
+ slices: int | slice | range | tuple[int | Iterable[int], int | Iterable[int]]
) -> SlicePolyMatrixMixin:
+ # For example v[0] it is implicitly converted to v[0, :]
+ if isinstance(slices, int | slice | range):
+ slices = (slices, slice(None, None, None))
+
formatted_slices: list[tuple] = [(), ()]
shape = [0, 0]