diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/init.py | 19 | ||||
-rw-r--r-- | polymatrix/expression/mixins/sliceexprmixin.py | 12 | ||||
-rw-r--r-- | polymatrix/polymatrix/init.py | 5 |
3 files changed, 32 insertions, 4 deletions
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index f0c2a0b..6f75ab3 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -394,9 +394,22 @@ def init_slice_expr( underlying: ExpressionBaseMixin, slices: int | slice | range | tuple[int | slice | range, int | slice | range] ): - # FIXME: For some reason in older python versions slice is not a hashable - # type and this causes crashes when old code uses the cache in the state object - # cf. https://stackoverflow.com/questions/29980786/why-are-slice-objects-not-hashable-in-python + # FIXME: see comment above this HashableSlice class + HashableSlice = polymatrix.expression.mixins.sliceexprmixin.HashableSlice + + if isinstance(slices, slice): + slices = HashableSlice(slices.start, slices.stop, slices.step) + + elif isinstance(slices, tuple): + new_slices = list(slices) + if isinstance(slices[0], slice): + new_slices[0] = HashableSlice(slices[0].start, slices[0].stop, slices[0].step) + + if isinstance(slices[1], slice): + new_slices[1] = HashableSlice(slices[1].start, slices[1].stop, slices[1].step) + + slices = tuple(new_slices) + return polymatrix.expression.impl.SliceExprImpl(underlying, slices) diff --git a/polymatrix/expression/mixins/sliceexprmixin.py b/polymatrix/expression/mixins/sliceexprmixin.py index f9350cd..ce87c21 100644 --- a/polymatrix/expression/mixins/sliceexprmixin.py +++ b/polymatrix/expression/mixins/sliceexprmixin.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import abstractmethod +from typing import NamedTuple from typing_extensions import override from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -9,6 +10,17 @@ from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.init import init_slice_poly_matrix +# For some reason in older python versions slice is not a hashable +# type and this causes crashes when old code uses the cache in the state object +# cf. https://stackoverflow.com/questions/29980786/why-are-slice-objects-not-hashable-in-python +# FIXME: Move this class to the correct place, eg. in global typing module or +# find a solution +class HashableSlice(NamedTuple): + start: int | None + stop: int | None + step: int | None + + class SliceExprMixin(ExpressionBaseMixin): """ Take a slice (one or more elements) of a matrix. diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py index 8b390c4..4326dd8 100644 --- a/polymatrix/polymatrix/init.py +++ b/polymatrix/polymatrix/init.py @@ -90,6 +90,9 @@ def init_slice_poly_matrix( formatted_slices: list[tuple] = [(), ()] shape = [0, 0] + # FIXME: see comment on HashableSlice class + from polymatrix.expression.mixins.sliceexprmixin import HashableSlice + for i, (what, el, numel) in enumerate(zip(("Row", "Column"), slices, reference.shape)): if isinstance(el, int): if not (0 <= el < numel): @@ -99,7 +102,7 @@ def init_slice_poly_matrix( formatted_slices[i] = (el,) shape[i] = 1 - elif isinstance(el, slice): + elif isinstance(el, slice | HashableSlice): # convert to range el = range(el.start or 0, el.stop or numel, el.step or 1) |