summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/expression.py7
-rw-r--r--polymatrix/expression/impl.py38
-rw-r--r--polymatrix/expression/init.py28
-rw-r--r--polymatrix/expression/mixins/getitemexprmixin.py99
-rw-r--r--polymatrix/expression/mixins/sliceexprmixin.py55
-rw-r--r--polymatrix/polymatrix/impl.py1
6 files changed, 87 insertions, 141 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index 94492ef..494a245 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -57,12 +57,11 @@ class Expression(ExpressionBaseMixin, ABC):
else:
return attr
- def __getitem__(self, key: tuple[int, int]) -> Expression:
- # FIXME: typing for key is incorrect, could be a slice
+ def __getitem__(self, slice: tuple[int | slice, int | slice]) -> Expression:
return self.copy(
- underlying=polymatrix.expression.init.init_get_item_expr(
+ underlying=polymatrix.expression.init.init_slice_expr(
underlying=self.underlying,
- index=key,
+ slice=slice
),
)
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index e4cae1b..2eeea39 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -5,57 +5,59 @@ from polymatrix.statemonad import StateMonad
import dataclassabc
-from polymatrix.expression.mixins.integrateexprmixin import IntegrateExprMixin
-from polymatrix.expression.mixins.legendreseriesmixin import LegendreSeriesMixin
-from polymatrix.expression.mixins.productexprmixin import ProductExprMixin
from polymatrix.utils.getstacklines import FrameSummary
-from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
from polymatrix.expression.mixins.additionexprmixin import AdditionExprMixin
from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin
from polymatrix.expression.mixins.cacheexprmixin import CacheExprMixin
from polymatrix.expression.mixins.combinationsexprmixin import CombinationsExprMixin
+from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin
from polymatrix.expression.mixins.derivativeexprmixin import DerivativeExprMixin
from polymatrix.expression.mixins.diagexprmixin import DiagExprMixin
from polymatrix.expression.mixins.divergenceexprmixin import DivergenceExprMixin
from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin
from polymatrix.expression.mixins.evalexprmixin import EvalExprMixin
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.mixins.eyeexprmixin import EyeExprMixin
from polymatrix.expression.mixins.filterexprmixin import FilterExprMixin
-from polymatrix.expression.mixins.fromsymmetricmatrixexprmixin import FromSymmetricMatrixExprMixin
from polymatrix.expression.mixins.fromnumbersexprmixin import FromNumbersExprMixin
from polymatrix.expression.mixins.fromnumpyexprmixin import FromNumpyExprMixin
from polymatrix.expression.mixins.fromstatemonad import FromStateMonadMixin
+from polymatrix.expression.mixins.fromsymmetricmatrixexprmixin import FromSymmetricMatrixExprMixin
from polymatrix.expression.mixins.fromsympyexprmixin import FromSympyExprMixin
-from polymatrix.expression.mixins.fromtermsexprmixin import (
- FromPolynomialDataExprMixin,
- PolynomialMatrixTupledData,
-)
-from polymatrix.expression.mixins.getitemexprmixin import GetItemExprMixin
from polymatrix.expression.mixins.halfnewtonpolytopeexprmixin import HalfNewtonPolytopeExprMixin
+from polymatrix.expression.mixins.integrateexprmixin import IntegrateExprMixin
+from polymatrix.expression.mixins.legendreseriesmixin import LegendreSeriesMixin
from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin
from polymatrix.expression.mixins.linearmatrixinexprmixin import LinearMatrixInExprMixin
from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomialsExprMixin
from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin
-from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin
from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin
from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin
from polymatrix.expression.mixins.powerexprmixin import PowerExprMixin
+from polymatrix.expression.mixins.productexprmixin import ProductExprMixin
from polymatrix.expression.mixins.quadraticinexprmixin import QuadraticInExprMixin
from polymatrix.expression.mixins.quadraticmonomialsexprmixin import QuadraticMonomialsExprMixin
from polymatrix.expression.mixins.repmatexprmixin import RepMatExprMixin
from polymatrix.expression.mixins.reshapeexprmixin import ReshapeExprMixin
from polymatrix.expression.mixins.setelementatexprmixin import SetElementAtExprMixin
+from polymatrix.expression.mixins.sliceexprmixin import SliceExprMixin
from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin
from polymatrix.expression.mixins.subtractmonomialsexprmixin import SubtractMonomialsExprMixin
from polymatrix.expression.mixins.sumexprmixin import SumExprMixin
from polymatrix.expression.mixins.symmetricexprmixin import SymmetricExprMixin
from polymatrix.expression.mixins.toconstantexprmixin import ToConstantExprMixin
+from polymatrix.expression.mixins.tosortedvariablesmixin import ToSortedVariablesExprMixin
from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ToSymmetricMatrixExprMixin
from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin
from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin
from polymatrix.expression.mixins.variablemixin import VariableMixin
from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin
-from polymatrix.expression.mixins.tosortedvariablesmixin import ToSortedVariablesExprMixin
+
+from polymatrix.expression.mixins.fromtermsexprmixin import (
+ FromPolynomialDataExprMixin,
+ PolynomialMatrixTupledData,
+)
@dataclassabc.dataclassabc(frozen=True)
@@ -191,12 +193,6 @@ class FromPolynomialDataExprImpl(FromPolynomialDataExprMixin):
@dataclassabc.dataclassabc(frozen=True)
-class GetItemExprImpl(GetItemExprMixin):
- underlying: ExpressionBaseMixin
- index: tuple[tuple[int, ...], tuple[int, ...]]
-
-
-@dataclassabc.dataclassabc(frozen=True)
class HalfNewtonPolytopeExprImpl(HalfNewtonPolytopeExprMixin):
monomials: ExpressionBaseMixin
variables: ExpressionBaseMixin
@@ -356,6 +352,12 @@ class SetElementAtExprImpl(SetElementAtExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class SliceExprImpl(SliceExprMixin):
+ underlying: ExpressionBaseMixin
+ slice: tuple[int | slice | range, int | slice | range]
+
+
+@dataclassabc.dataclassabc(frozen=True)
class SqueezeExprImpl(SqueezeExprMixin):
underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 5e72524..30fd3e1 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -253,26 +253,6 @@ def init_from_terms_expr(
)
-def init_get_item_expr(
- underlying: ExpressionBaseMixin,
- index: tuple[tuple[int, ...], tuple[int, ...]],
-):
- def get_hashable_slice(index):
- if isinstance(index, slice):
- return polymatrix.expression.impl.GetItemExprImpl.Slice(
- start=index.start, stop=index.stop, step=index.step
- )
- else:
- return index
-
- proper_index = (get_hashable_slice(index[0]), get_hashable_slice(index[1]))
-
- return polymatrix.expression.impl.GetItemExprImpl(
- underlying=underlying,
- index=proper_index,
- )
-
-
def init_half_newton_polytope_expr(
monomials: ExpressionBaseMixin,
variables: ExpressionBaseMixin,
@@ -396,6 +376,14 @@ def init_set_element_at_expr(
)
+def init_slice_expr(
+ underlying: ExpressionBaseMixin,
+ slice: tuple[int | slice | range, int | slice | range]
+):
+ return polymatrix.expression.impl.SliceExprImpl(underlying, slice)
+
+
+
def init_squeeze_expr(
underlying: ExpressionBaseMixin,
):
diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py
deleted file mode 100644
index 8e95051..0000000
--- a/polymatrix/expression/mixins/getitemexprmixin.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import abc
-import dataclasses
-import dataclassabc
-from polymatrix.polymatrix.mixins import PolyMatrixMixin
-
-from polymatrix.polymatrix.init import init_poly_matrix
-from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.polymatrix.abc import PolyMatrix
-from polymatrix.polymatrix.index import PolyDict
-from polymatrix.expressionstate import ExpressionState
-
-
-class GetItemExprMixin(ExpressionBaseMixin):
- @dataclasses.dataclass(frozen=True)
- class Slice:
- start: int
- step: int
- stop: int
-
- @property
- @abc.abstractmethod
- def underlying(self) -> ExpressionBaseMixin: ...
-
- @property
- @abc.abstractmethod
- def index(self) -> tuple[tuple[int, ...], tuple[int, ...]]: ...
-
- # overwrites the abstract method of `ExpressionBaseMixin`
- def apply(
- self,
- state: ExpressionState,
- ) -> tuple[ExpressionState, PolyMatrix]:
- state, underlying = self.underlying.apply(state=state)
-
- def get_proper_index(index, shape):
- if isinstance(index, tuple):
- return index
-
- elif isinstance(index, GetItemExprMixin.Slice):
- if index.start is None:
- start = 0
- else:
- start = index.start
-
- if index.stop is None:
- stop = shape
- else:
- stop = index.stop
-
- if index.step is None:
- step = 1
- else:
- step = index.step
-
- return tuple(range(start, stop, step))
-
- else:
- return (index,)
-
- proper_index = (
- get_proper_index(self.index[0], underlying.shape[0]),
- get_proper_index(self.index[1], underlying.shape[1]),
- )
-
- # FIXME: move to polymatrix module
- @dataclassabc.dataclassabc(frozen=True)
- class GetItemPolyMatrix(PolyMatrixMixin):
- underlying: PolyMatrixMixin
- index: tuple[int, int]
-
- @property
- def shape(self) -> tuple[int, int]:
- return (len(self.index[0]), len(self.index[1]))
-
- def at(self, row: int, col: int) -> PolyDict:
- # FIXME: this is a quick fix
- return self.get_poly(row, col) or PolyDict.empty()
-
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
- try:
- n_row = self.index[0][row]
- except IndexError:
- raise IndexError(
- f"tuple index {row} out of range given {self.index[0]}"
- )
-
- try:
- n_col = self.index[1][col]
- except IndexError:
- raise IndexError(
- f"tuple index {col} out of range given {self.index[1]}"
- )
-
- return self.underlying.get_poly(n_row, n_col)
-
- return state, GetItemPolyMatrix(
- underlying=underlying,
- index=proper_index,
- )
diff --git a/polymatrix/expression/mixins/sliceexprmixin.py b/polymatrix/expression/mixins/sliceexprmixin.py
new file mode 100644
index 0000000..f9350cd
--- /dev/null
+++ b/polymatrix/expression/mixins/sliceexprmixin.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing_extensions import override
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate import ExpressionState
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.polymatrix.init import init_slice_poly_matrix
+
+
+class SliceExprMixin(ExpressionBaseMixin):
+ """
+ Take a slice (one or more elements) of a matrix.
+
+ Examples:
+ Suppose there is the following matrix
+
+ M =
+ [ a b c d ]
+ [ e f g h ]
+ [ i j k l ]
+ [ m n o p ]
+
+ By slicing M[0,3] we get element i.
+ By slicing M[:,3] we get the column [[c] [g] [k] [o]]
+ By slicing M[:2,:2] we get the submatrix
+
+ [ a b ]
+ [ e f ]
+
+ and so on.
+ """
+
+ @property
+ @abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ """ Expression to take the slice from. """
+
+ @property
+ @abstractmethod
+ def slice(self) -> tuple:
+ """ The slice. """
+ # Type / format of this property must match of slice accepted by
+ # SlicePolyMatrix, since it directly uses that see
+ # polymatrix.polymatrix.init.init_poly_matrix
+
+ # TODO: allow slice to be an Expression that evaluates to a number or
+ # vector of numbers
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
+ state, p = self.underlying.apply(state)
+ return state, init_slice_poly_matrix(p, self.slice)
+
diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py
index 1d3e1b9..816e4e1 100644
--- a/polymatrix/polymatrix/impl.py
+++ b/polymatrix/polymatrix/impl.py
@@ -21,6 +21,7 @@ class BroadcastPolyMatrixImpl(BroadcastPolyMatrixMixin):
class SlicePolyMatrixImpl(SlicePolyMatrixMixin):
reference: PolyMatrixMixin
shape: tuple[int, int]
+ slice: tuple[tuple[int, ...], tuple[int, ...]]
@dataclassabc.dataclassabc(frozen=True)