summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/init.py19
-rw-r--r--polymatrix/expression/mixins/sliceexprmixin.py12
-rw-r--r--polymatrix/polymatrix/init.py5
3 files changed, 32 insertions, 4 deletions
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index f0c2a0b..6f75ab3 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -394,9 +394,22 @@ def init_slice_expr(
underlying: ExpressionBaseMixin,
slices: int | slice | range | tuple[int | slice | range, int | slice | range]
):
- # FIXME: For some reason in older python versions slice is not a hashable
- # type and this causes crashes when old code uses the cache in the state object
- # cf. https://stackoverflow.com/questions/29980786/why-are-slice-objects-not-hashable-in-python
+ # FIXME: see comment above this HashableSlice class
+ HashableSlice = polymatrix.expression.mixins.sliceexprmixin.HashableSlice
+
+ if isinstance(slices, slice):
+ slices = HashableSlice(slices.start, slices.stop, slices.step)
+
+ elif isinstance(slices, tuple):
+ new_slices = list(slices)
+ if isinstance(slices[0], slice):
+ new_slices[0] = HashableSlice(slices[0].start, slices[0].stop, slices[0].step)
+
+ if isinstance(slices[1], slice):
+ new_slices[1] = HashableSlice(slices[1].start, slices[1].stop, slices[1].step)
+
+ slices = tuple(new_slices)
+
return polymatrix.expression.impl.SliceExprImpl(underlying, slices)
diff --git a/polymatrix/expression/mixins/sliceexprmixin.py b/polymatrix/expression/mixins/sliceexprmixin.py
index f9350cd..ce87c21 100644
--- a/polymatrix/expression/mixins/sliceexprmixin.py
+++ b/polymatrix/expression/mixins/sliceexprmixin.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import abstractmethod
+from typing import NamedTuple
from typing_extensions import override
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -9,6 +10,17 @@ from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.polymatrix.init import init_slice_poly_matrix
+# For some reason in older python versions slice is not a hashable
+# type and this causes crashes when old code uses the cache in the state object
+# cf. https://stackoverflow.com/questions/29980786/why-are-slice-objects-not-hashable-in-python
+# FIXME: Move this class to the correct place, eg. in global typing module or
+# find a solution
+class HashableSlice(NamedTuple):
+ start: int | None
+ stop: int | None
+ step: int | None
+
+
class SliceExprMixin(ExpressionBaseMixin):
"""
Take a slice (one or more elements) of a matrix.
diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py
index 8b390c4..4326dd8 100644
--- a/polymatrix/polymatrix/init.py
+++ b/polymatrix/polymatrix/init.py
@@ -90,6 +90,9 @@ def init_slice_poly_matrix(
formatted_slices: list[tuple] = [(), ()]
shape = [0, 0]
+ # FIXME: see comment on HashableSlice class
+ from polymatrix.expression.mixins.sliceexprmixin import HashableSlice
+
for i, (what, el, numel) in enumerate(zip(("Row", "Column"), slices, reference.shape)):
if isinstance(el, int):
if not (0 <= el < numel):
@@ -99,7 +102,7 @@ def init_slice_poly_matrix(
formatted_slices[i] = (el,)
shape[i] = 1
- elif isinstance(el, slice):
+ elif isinstance(el, slice | HashableSlice):
# convert to range
el = range(el.start or 0, el.stop or numel, el.step or 1)