diff options
-rw-r--r-- | polymatrix/expression/expression.py | 7 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 38 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 28 | ||||
-rw-r--r-- | polymatrix/expression/mixins/getitemexprmixin.py | 99 | ||||
-rw-r--r-- | polymatrix/expression/mixins/sliceexprmixin.py | 55 | ||||
-rw-r--r-- | polymatrix/polymatrix/impl.py | 1 |
6 files changed, 87 insertions, 141 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 94492ef..494a245 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -57,12 +57,11 @@ class Expression(ExpressionBaseMixin, ABC): else: return attr - def __getitem__(self, key: tuple[int, int]) -> Expression: - # FIXME: typing for key is incorrect, could be a slice + def __getitem__(self, slice: tuple[int | slice, int | slice]) -> Expression: return self.copy( - underlying=polymatrix.expression.init.init_get_item_expr( + underlying=polymatrix.expression.init.init_slice_expr( underlying=self.underlying, - index=key, + slice=slice ), ) diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index e4cae1b..2eeea39 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -5,57 +5,59 @@ from polymatrix.statemonad import StateMonad import dataclassabc -from polymatrix.expression.mixins.integrateexprmixin import IntegrateExprMixin -from polymatrix.expression.mixins.legendreseriesmixin import LegendreSeriesMixin -from polymatrix.expression.mixins.productexprmixin import ProductExprMixin from polymatrix.utils.getstacklines import FrameSummary -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + from polymatrix.expression.mixins.additionexprmixin import AdditionExprMixin from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin from polymatrix.expression.mixins.cacheexprmixin import CacheExprMixin from polymatrix.expression.mixins.combinationsexprmixin import CombinationsExprMixin +from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin from polymatrix.expression.mixins.derivativeexprmixin import DerivativeExprMixin from polymatrix.expression.mixins.diagexprmixin import DiagExprMixin from polymatrix.expression.mixins.divergenceexprmixin import DivergenceExprMixin from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin from polymatrix.expression.mixins.evalexprmixin import EvalExprMixin +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.eyeexprmixin import EyeExprMixin from polymatrix.expression.mixins.filterexprmixin import FilterExprMixin -from polymatrix.expression.mixins.fromsymmetricmatrixexprmixin import FromSymmetricMatrixExprMixin from polymatrix.expression.mixins.fromnumbersexprmixin import FromNumbersExprMixin from polymatrix.expression.mixins.fromnumpyexprmixin import FromNumpyExprMixin from polymatrix.expression.mixins.fromstatemonad import FromStateMonadMixin +from polymatrix.expression.mixins.fromsymmetricmatrixexprmixin import FromSymmetricMatrixExprMixin from polymatrix.expression.mixins.fromsympyexprmixin import FromSympyExprMixin -from polymatrix.expression.mixins.fromtermsexprmixin import ( - FromPolynomialDataExprMixin, - PolynomialMatrixTupledData, -) -from polymatrix.expression.mixins.getitemexprmixin import GetItemExprMixin from polymatrix.expression.mixins.halfnewtonpolytopeexprmixin import HalfNewtonPolytopeExprMixin +from polymatrix.expression.mixins.integrateexprmixin import IntegrateExprMixin +from polymatrix.expression.mixins.legendreseriesmixin import LegendreSeriesMixin from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin from polymatrix.expression.mixins.linearmatrixinexprmixin import LinearMatrixInExprMixin from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomialsExprMixin from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin -from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin from polymatrix.expression.mixins.powerexprmixin import PowerExprMixin +from polymatrix.expression.mixins.productexprmixin import ProductExprMixin from polymatrix.expression.mixins.quadraticinexprmixin import QuadraticInExprMixin from polymatrix.expression.mixins.quadraticmonomialsexprmixin import QuadraticMonomialsExprMixin from polymatrix.expression.mixins.repmatexprmixin import RepMatExprMixin from polymatrix.expression.mixins.reshapeexprmixin import ReshapeExprMixin from polymatrix.expression.mixins.setelementatexprmixin import SetElementAtExprMixin +from polymatrix.expression.mixins.sliceexprmixin import SliceExprMixin from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin from polymatrix.expression.mixins.subtractmonomialsexprmixin import SubtractMonomialsExprMixin from polymatrix.expression.mixins.sumexprmixin import SumExprMixin from polymatrix.expression.mixins.symmetricexprmixin import SymmetricExprMixin from polymatrix.expression.mixins.toconstantexprmixin import ToConstantExprMixin +from polymatrix.expression.mixins.tosortedvariablesmixin import ToSortedVariablesExprMixin from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ToSymmetricMatrixExprMixin from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin from polymatrix.expression.mixins.variablemixin import VariableMixin from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin -from polymatrix.expression.mixins.tosortedvariablesmixin import ToSortedVariablesExprMixin + +from polymatrix.expression.mixins.fromtermsexprmixin import ( + FromPolynomialDataExprMixin, + PolynomialMatrixTupledData, +) @dataclassabc.dataclassabc(frozen=True) @@ -191,12 +193,6 @@ class FromPolynomialDataExprImpl(FromPolynomialDataExprMixin): @dataclassabc.dataclassabc(frozen=True) -class GetItemExprImpl(GetItemExprMixin): - underlying: ExpressionBaseMixin - index: tuple[tuple[int, ...], tuple[int, ...]] - - -@dataclassabc.dataclassabc(frozen=True) class HalfNewtonPolytopeExprImpl(HalfNewtonPolytopeExprMixin): monomials: ExpressionBaseMixin variables: ExpressionBaseMixin @@ -356,6 +352,12 @@ class SetElementAtExprImpl(SetElementAtExprMixin): @dataclassabc.dataclassabc(frozen=True) +class SliceExprImpl(SliceExprMixin): + underlying: ExpressionBaseMixin + slice: tuple[int | slice | range, int | slice | range] + + +@dataclassabc.dataclassabc(frozen=True) class SqueezeExprImpl(SqueezeExprMixin): underlying: ExpressionBaseMixin diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 5e72524..30fd3e1 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -253,26 +253,6 @@ def init_from_terms_expr( ) -def init_get_item_expr( - underlying: ExpressionBaseMixin, - index: tuple[tuple[int, ...], tuple[int, ...]], -): - def get_hashable_slice(index): - if isinstance(index, slice): - return polymatrix.expression.impl.GetItemExprImpl.Slice( - start=index.start, stop=index.stop, step=index.step - ) - else: - return index - - proper_index = (get_hashable_slice(index[0]), get_hashable_slice(index[1])) - - return polymatrix.expression.impl.GetItemExprImpl( - underlying=underlying, - index=proper_index, - ) - - def init_half_newton_polytope_expr( monomials: ExpressionBaseMixin, variables: ExpressionBaseMixin, @@ -396,6 +376,14 @@ def init_set_element_at_expr( ) +def init_slice_expr( + underlying: ExpressionBaseMixin, + slice: tuple[int | slice | range, int | slice | range] +): + return polymatrix.expression.impl.SliceExprImpl(underlying, slice) + + + def init_squeeze_expr( underlying: ExpressionBaseMixin, ): diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py deleted file mode 100644 index 8e95051..0000000 --- a/polymatrix/expression/mixins/getitemexprmixin.py +++ /dev/null @@ -1,99 +0,0 @@ -import abc -import dataclasses -import dataclassabc -from polymatrix.polymatrix.mixins import PolyMatrixMixin - -from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.polymatrix.index import PolyDict -from polymatrix.expressionstate import ExpressionState - - -class GetItemExprMixin(ExpressionBaseMixin): - @dataclasses.dataclass(frozen=True) - class Slice: - start: int - step: int - stop: int - - @property - @abc.abstractmethod - def underlying(self) -> ExpressionBaseMixin: ... - - @property - @abc.abstractmethod - def index(self) -> tuple[tuple[int, ...], tuple[int, ...]]: ... - - # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: - state, underlying = self.underlying.apply(state=state) - - def get_proper_index(index, shape): - if isinstance(index, tuple): - return index - - elif isinstance(index, GetItemExprMixin.Slice): - if index.start is None: - start = 0 - else: - start = index.start - - if index.stop is None: - stop = shape - else: - stop = index.stop - - if index.step is None: - step = 1 - else: - step = index.step - - return tuple(range(start, stop, step)) - - else: - return (index,) - - proper_index = ( - get_proper_index(self.index[0], underlying.shape[0]), - get_proper_index(self.index[1], underlying.shape[1]), - ) - - # FIXME: move to polymatrix module - @dataclassabc.dataclassabc(frozen=True) - class GetItemPolyMatrix(PolyMatrixMixin): - underlying: PolyMatrixMixin - index: tuple[int, int] - - @property - def shape(self) -> tuple[int, int]: - return (len(self.index[0]), len(self.index[1])) - - def at(self, row: int, col: int) -> PolyDict: - # FIXME: this is a quick fix - return self.get_poly(row, col) or PolyDict.empty() - - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: - try: - n_row = self.index[0][row] - except IndexError: - raise IndexError( - f"tuple index {row} out of range given {self.index[0]}" - ) - - try: - n_col = self.index[1][col] - except IndexError: - raise IndexError( - f"tuple index {col} out of range given {self.index[1]}" - ) - - return self.underlying.get_poly(n_row, n_col) - - return state, GetItemPolyMatrix( - underlying=underlying, - index=proper_index, - ) diff --git a/polymatrix/expression/mixins/sliceexprmixin.py b/polymatrix/expression/mixins/sliceexprmixin.py new file mode 100644 index 0000000..f9350cd --- /dev/null +++ b/polymatrix/expression/mixins/sliceexprmixin.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing_extensions import override + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.polymatrix.init import init_slice_poly_matrix + + +class SliceExprMixin(ExpressionBaseMixin): + """ + Take a slice (one or more elements) of a matrix. + + Examples: + Suppose there is the following matrix + + M = + [ a b c d ] + [ e f g h ] + [ i j k l ] + [ m n o p ] + + By slicing M[0,3] we get element i. + By slicing M[:,3] we get the column [[c] [g] [k] [o]] + By slicing M[:2,:2] we get the submatrix + + [ a b ] + [ e f ] + + and so on. + """ + + @property + @abstractmethod + def underlying(self) -> ExpressionBaseMixin: + """ Expression to take the slice from. """ + + @property + @abstractmethod + def slice(self) -> tuple: + """ The slice. """ + # Type / format of this property must match of slice accepted by + # SlicePolyMatrix, since it directly uses that see + # polymatrix.polymatrix.init.init_poly_matrix + + # TODO: allow slice to be an Expression that evaluates to a number or + # vector of numbers + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: + state, p = self.underlying.apply(state) + return state, init_slice_poly_matrix(p, self.slice) + diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py index 1d3e1b9..816e4e1 100644 --- a/polymatrix/polymatrix/impl.py +++ b/polymatrix/polymatrix/impl.py @@ -21,6 +21,7 @@ class BroadcastPolyMatrixImpl(BroadcastPolyMatrixMixin): class SlicePolyMatrixImpl(SlicePolyMatrixMixin): reference: PolyMatrixMixin shape: tuple[int, int] + slice: tuple[tuple[int, ...], tuple[int, ...]] @dataclassabc.dataclassabc(frozen=True) |