From c48eac2f053322c62ed7deaa709bc05984dbf9a0 Mon Sep 17 00:00:00 2001 From: Michael Schneeberger Date: Mon, 26 Feb 2024 08:53:32 +0100 Subject: clean ups --- polymatrix/__init__.py | 23 ++-- polymatrix/denserepr/from_.py | 6 +- polymatrix/expression/expression.py | 139 +++++++-------------- polymatrix/expression/from_.py | 33 +---- polymatrix/expression/impl.py | 21 +++- polymatrix/expression/init.py | 99 ++------------- polymatrix/expression/mixins/additionexprmixin.py | 79 +++++++----- polymatrix/expression/mixins/blockdiagexprmixin.py | 2 +- polymatrix/expression/mixins/cacheexprmixin.py | 2 +- .../expression/mixins/combinationsexprmixin.py | 33 +++-- polymatrix/expression/mixins/degreeexprmixin.py | 52 ++++++++ .../expression/mixins/derivativeexprmixin.py | 2 +- .../expression/mixins/determinantexprmixin.py | 4 +- polymatrix/expression/mixins/diagexprmixin.py | 2 +- .../expression/mixins/divergenceexprmixin.py | 2 +- polymatrix/expression/mixins/divisionexprmixin.py | 2 +- polymatrix/expression/mixins/elemmultexprmixin.py | 4 +- polymatrix/expression/mixins/evalexprmixin.py | 2 +- polymatrix/expression/mixins/eyeexprmixin.py | 2 +- polymatrix/expression/mixins/filterexprmixin.py | 5 +- .../expression/mixins/filterlinearpartexprmixin.py | 2 +- .../mixins/fromsymmetricmatrixexprmixin.py | 2 +- polymatrix/expression/mixins/fromtermsexprmixin.py | 2 +- polymatrix/expression/mixins/fromtupleexprmixin.py | 5 +- polymatrix/expression/mixins/getitemexprmixin.py | 2 +- .../mixins/halfnewtonpolytopeexprmixin.py | 2 +- .../expression/mixins/legendreseriesmixin.py | 2 +- polymatrix/expression/mixins/linearinexprmixin.py | 2 +- .../expression/mixins/linearmatrixinexprmixin.py | 2 +- .../expression/mixins/linearmonomialsexprmixin.py | 2 +- .../expression/mixins/matrixmultexprmixin.py | 4 +- polymatrix/expression/mixins/maxdegreeexprmixin.py | 43 ------- polymatrix/expression/mixins/maxexprmixin.py | 2 +- .../expression/mixins/parametrizeexprmixin.py | 2 +- .../mixins/parametrizematrixexprmixin.py | 2 +- polymatrix/expression/mixins/productexprmixin.py | 36 +++--- .../expression/mixins/quadraticinexprmixin.py | 2 +- .../mixins/quadraticmonomialsexprmixin.py | 2 +- polymatrix/expression/mixins/repmatexprmixin.py | 2 +- polymatrix/expression/mixins/reshapeexprmixin.py | 2 +- .../expression/mixins/setelementatexprmixin.py | 2 +- polymatrix/expression/mixins/squeezeexprmixin.py | 2 +- .../expression/mixins/substituteexprmixin.py | 2 +- .../mixins/subtractmonomialsexprmixin.py | 2 +- polymatrix/expression/mixins/sumexprmixin.py | 2 +- polymatrix/expression/mixins/symmetricexprmixin.py | 2 +- .../expression/mixins/toconstantexprmixin.py | 2 +- .../expression/mixins/toquadraticexprmixin.py | 2 +- .../expression/mixins/tosortedvariablesmixin.py | 2 +- .../mixins/tosymmetricmatrixexprmixin.py | 2 +- polymatrix/expression/mixins/transposeexprmixin.py | 2 +- polymatrix/expression/mixins/truncateexprmixin.py | 14 ++- polymatrix/expression/mixins/vstackexprmixin.py | 2 +- polymatrix/expression/op.py | 30 ++++- polymatrix/expression/to.py | 62 ++++----- polymatrix/expression/utils/multiplypolynomial.py | 18 +-- polymatrix/polymatrix/impl.py | 9 +- polymatrix/polymatrix/mixins.py | 20 ++- polymatrix/polymatrix/typing.py | 4 + polymatrix/statemonad/mixins.py | 3 + 60 files changed, 386 insertions(+), 432 deletions(-) create mode 100644 polymatrix/expression/mixins/degreeexprmixin.py delete mode 100644 polymatrix/expression/mixins/maxdegreeexprmixin.py create mode 100644 polymatrix/polymatrix/typing.py diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index e513300..6592238 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -6,11 +6,11 @@ from polymatrix.expression.from_ import from_ as internal_from from polymatrix.expression import v_stack as internal_v_stack from polymatrix.expression import h_stack as internal_h_stack from polymatrix.expression import product as internal_product -from polymatrix.denserepr.from_ import from_polymatrix_expr -from polymatrix.expression.to import shape as internal_shape -from polymatrix.expression.to import to_constant_repr as internal_to_constant_repr -from polymatrix.expression.to import to_degrees as internal_to_degrees -from polymatrix.expression.to import to_sympy_repr as internal_to_sympy_repr +from polymatrix.denserepr.from_ import from_polymatrix +# from polymatrix.expression.to import shape as internal_shape +from polymatrix.expression.to import to_constant as internal_to_constant +# from polymatrix.expression.to import to_degrees as internal_to_degrees +from polymatrix.expression.to import to_sympy as internal_to_sympy Expression = internal_Expression ExpressionState = internal_ExpressionState @@ -21,11 +21,14 @@ from_ = internal_from v_stack = internal_v_stack h_stack = internal_h_stack product = internal_product -to_shape = internal_shape -to_constant_repr = internal_to_constant_repr -to_degrees = internal_to_degrees -to_sympy_repr = internal_to_sympy_repr -to_matrix_repr = from_polymatrix_expr +# to_shape = internal_shape +to_constant_repr = internal_to_constant +to_constant = internal_to_constant +# to_degrees = internal_to_degrees +to_sympy_repr = internal_to_sympy +to_sympy = internal_to_sympy +to_matrix_repr = from_polymatrix +to_dense = from_polymatrix # def from_sympy( # data: tuple[tuple[float]], diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py index b61af86..a8540e6 100644 --- a/polymatrix/denserepr/from_.py +++ b/polymatrix/denserepr/from_.py @@ -1,8 +1,6 @@ -import dataclasses import itertools -import typing import numpy as np -import scipy.sparse + from polymatrix.denserepr.impl import DenseReprBufferImpl, DenseReprImpl from polymatrix.expression.expression import Expression @@ -14,7 +12,7 @@ from polymatrix.statemonad.mixins import StateMonadMixin from polymatrix.expression.utils.monomialtoindex import monomial_to_index -def from_polymatrix_expr( +def from_polymatrix( expressions: Expression | tuple[Expression], variables: Expression = None, sorted: bool = None, diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index dc2f144..f5318ae 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -9,7 +9,7 @@ import polymatrix.expression.init from polymatrix.utils.getstacklines import get_stack_lines from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expression.op import diff, linear_in, linear_monomials, legendre +from polymatrix.expression.op import diff, linear_in, linear_monomials, legendre, filter_, degree from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState @@ -24,7 +24,7 @@ class Expression( def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `PolyMatrixExprBaseMixin` + # overwrites the abstract method of `PolyMatrixExprBaseMixin` def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: return self.underlying.apply(state) @@ -44,8 +44,7 @@ class Expression( return attr def __getitem__(self, key: tuple[int, int]): - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_get_item_expr( underlying=self.underlying, index=key, @@ -118,14 +117,6 @@ class Expression( underlying=op(left, right, stack), ) - # def _convert_to_expression(self, other): - # result = init_from_expr_or_none(other) - - # if result is None: - # return NotImplemented - - # return result - def cache(self) -> 'Expression': return self.copy( underlying=polymatrix.expression.init.init_cache_expr( @@ -143,18 +134,23 @@ class Expression( degrees=degrees, ), ) + + def degree(self) -> 'Expression': + return self.copy( + underlying=degree( + underlying=self.underlying, + ), + ) def determinant(self) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_determinant_expr( underlying=self.underlying, ), ) def diag(self): - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_diag_expr( underlying=self.underlying, ), @@ -165,8 +161,7 @@ class Expression( variables: 'Expression', introduce_derivatives: bool = None, ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=diff( expression=self, variables=variables, @@ -178,8 +173,7 @@ class Expression( self, variables: tuple, ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_divergence_expr( underlying=self.underlying, variables=variables, @@ -191,8 +185,7 @@ class Expression( variable: tuple, value: tuple[float, ...] = None, ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_eval_expr( underlying=self.underlying, variables=variable, @@ -206,9 +199,8 @@ class Expression( predicator: 'Expression', inverse: bool = None, ) -> 'Expression': - return dataclasses.replace( - self, - underlying=polymatrix.expression.init.init_filter_expr( + return self.copy( + underlying=filter_( underlying=self.underlying, predicator=predicator, inverse=inverse, @@ -217,8 +209,7 @@ class Expression( # only applies to symmetric matrix def from_symmetric_matrix(self) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_from_symmetric_matrix_expr( underlying=self.underlying, ), @@ -230,8 +221,7 @@ class Expression( variables: 'Expression', filter: 'Expression | None' = None, ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_half_newton_polytope_expr( monomials=self.underlying, variables=variables, @@ -240,8 +230,7 @@ class Expression( ) def linear_matrix_in(self, variable: 'Expression') -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_linear_matrix_in_expr( underlying=self.underlying, variable=variable, @@ -253,8 +242,7 @@ class Expression( variables: 'Expression', ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=linear_monomials( expression=self.underlying, variables=variables, @@ -268,8 +256,7 @@ class Expression( ignore_unmatched: bool = None, ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=linear_in( expression=self.underlying, monomials=monomials, @@ -282,8 +269,7 @@ class Expression( self, degrees: tuple[int, ...] = None, ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=legendre( expression=self.underlying, degrees=degrees, @@ -291,47 +277,27 @@ class Expression( ) def max(self) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_max_expr( underlying=self.underlying, ), ) - def max_degree(self) -> 'Expression': - return dataclasses.replace( - self, - underlying=polymatrix.expression.init.init_max_degree_expr( - underlying=self.underlying, - ), - ) - def parametrize(self, name: str = None) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_parametrize_expr( underlying=self.underlying, name=name, ), ) - # def parametrize_matrix(self, name: str = None) -> 'ExpressionMixin': - # return dataclasses.replace( - # self, - # underlying=init_parametrize_matrix_expr( - # underlying=self.underlying, - # name=name, - # ), - # ) - def quadratic_in(self, variables: 'Expression', monomials: 'Expression' = None) -> 'Expression': if monomials is None: monomials = self.quadratic_monomials(variables) stack = get_stack_lines() - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_quadratic_in_expr( underlying=self.underlying, monomials=monomials, @@ -344,8 +310,7 @@ class Expression( self, variables: 'Expression', ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_quadratic_monomials_expr( underlying=self.underlying, variables=variables, @@ -353,8 +318,7 @@ class Expression( ) def reshape(self, n: int, m: int) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_reshape_expr( underlying=self.underlying, new_shape=(n, m), @@ -362,8 +326,7 @@ class Expression( ) def rep_mat(self, n: int, m: int) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_rep_mat_expr( underlying=self.underlying, repetition=(n, m), @@ -381,8 +344,7 @@ class Expression( else: value = polymatrix.expression.init.init_from_expr(value) - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_set_element_at_expr( underlying=self.underlying, index=(row, col), @@ -394,8 +356,7 @@ class Expression( def squeeze( self, ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_squeeze_expr( underlying=self.underlying, ), @@ -406,8 +367,7 @@ class Expression( self, monomials: 'Expression', ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_subtract_monomials_expr( underlying=self.underlying, monomials=monomials, @@ -419,8 +379,7 @@ class Expression( variable: tuple, values: tuple['Expression', ...] = None, ) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_substitute_expr( underlying=self.underlying, variables=variable, @@ -439,24 +398,21 @@ class Expression( ) def sum(self): - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_sum_expr( underlying=self.underlying, ), ) def symmetric(self): - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_symmetric_expr( underlying=self.underlying, ), ) def transpose(self) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_transpose_expr( underlying=self.underlying, ), @@ -467,24 +423,14 @@ class Expression( return self.transpose() def to_constant(self) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_to_constant_expr( underlying=self.underlying, ), ) - # def to_quadratic(self) -> 'ExpressionMixin': - # return dataclasses.replace( - # self, - # underlying=init_to_quadratic_expr( - # underlying=self.underlying, - # ), - # ) - def to_symmetric_matrix(self) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_to_symmetric_matrix_expr( underlying=self.underlying, ), @@ -492,8 +438,7 @@ class Expression( # only applies to variables def to_sorted_variables(self) -> 'Expression': - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_to_sorted_variables( underlying=self.underlying, ), @@ -502,12 +447,11 @@ class Expression( # also applies to monomials? def truncate( self, - variables: tuple, degrees: tuple[int], + variables: tuple | None = None, inverse: bool = None, ): - return dataclasses.replace( - self, + return self.copy( underlying=polymatrix.expression.init.init_truncate_expr( underlying=self.underlying, variables=variables, @@ -531,6 +475,7 @@ class ExpressionImpl(Expression): ) + def init_expression( underlying: ExpressionBaseMixin, ): diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index c290a67..daa916a 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -5,12 +5,13 @@ import polymatrix.expression.init from polymatrix.expression.expression import init_expression, Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.statemonad.abc import StateMonad -DATA_TYPE = str | np.ndarray | sympy.Matrix | sympy.Expr | tuple | ExpressionBaseMixin +FromDataTypes = str | np.ndarray | sympy.Matrix | sympy.Expr | tuple | ExpressionBaseMixin | StateMonad def from_expr_or_none( - data: DATA_TYPE, + data: FromDataTypes, ) -> Expression | None: return init_expression( @@ -20,7 +21,7 @@ def from_expr_or_none( ) def from_( - data: DATA_TYPE, + data: FromDataTypes, ) -> Expression: return init_expression( @@ -28,29 +29,3 @@ def from_( data=data, ), ) - -# def from_expr( -# data: DATA_TYPE, -# ) -> Expression: -# return from_(data=data) - -# def from_sympy( -# data: tuple[tuple[float]], -# ): -# return init_expression( -# polymatrix.expression.init.init_from_sympy_expr(data) -# ) - -# def from_state_monad( -# data: StateMonad, -# ): -# return init_expression( -# data.flat_map(lambda inner_data: polymatrix.expression.init.init_from_sympy_expr(inner_data)), -# ) - -# def from_polymatrix( -# polymatrix: PolyMatrix, -# ): -# return init_expression( -# polymatrix.expression.init.init_from_terms_expr(polymatrix) -# ) diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 3e61941..937676d 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -1,8 +1,9 @@ +import typing +import dataclassabc + from polymatrix.expression.mixins.legendreseriesmixin import LegendreSeriesMixin from polymatrix.expression.mixins.productexprmixin import ProductExprMixin from polymatrix.utils.getstacklines import FrameSummary -import dataclassabc - from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.additionexprmixin import AdditionExprMixin from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin @@ -35,7 +36,7 @@ from polymatrix.expression.mixins.linearmonomialsexprmixin import \ LinearMonomialsExprMixin from polymatrix.expression.mixins.matrixmultexprmixin import \ MatrixMultExprMixin -from polymatrix.expression.mixins.maxdegreeexprmixin import MaxDegreeExprMixin +from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin from polymatrix.expression.mixins.parametrizeexprmixin import \ ParametrizeExprMixin @@ -150,7 +151,7 @@ class EyeExprImpl(EyeExprMixin): class FilterExprImpl(FilterExprMixin): underlying: ExpressionBaseMixin predicator: ExpressionBaseMixin - inverse: bool == None + inverse: bool @dataclassabc.dataclassabc(frozen=True) @@ -211,6 +212,9 @@ class LegendreSeriesImpl(LegendreSeriesMixin): degrees: tuple[int, ...] | None stack: tuple[FrameSummary] + def __repr__(self): + return f'{self.__class__.__name__}(underlying={self.underlying}, degrees={self.degrees})' + @dataclassabc.dataclassabc(frozen=True) class MatrixMultExprImpl(MatrixMultExprMixin): @@ -223,8 +227,12 @@ class MatrixMultExprImpl(MatrixMultExprMixin): @dataclassabc.dataclassabc(frozen=True) -class MaxDegreeExprImpl(MaxDegreeExprMixin): +class DegreeExprImpl(DegreeExprMixin): underlying: ExpressionBaseMixin + stack: tuple[FrameSummary] + + def __repr__(self): + return f'{self.__class__.__name__}(underlying={self.underlying})' @dataclassabc.dataclassabc(frozen=True) @@ -253,6 +261,9 @@ class ProductExprImpl(ProductExprMixin): degrees: tuple[int, ...] | None stack: tuple[FrameSummary] + def __repr__(self): + return f'{self.__class__.__name__}(underlying={self.underlying}, degrees={self.degrees})' + @dataclassabc.dataclassabc(frozen=True) class QuadraticInExprImpl(QuadraticInExprMixin): diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index b7be83a..5e4eac1 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -4,11 +4,12 @@ import sympy import polymatrix.expression.impl +from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.statemonad.abc import StateMonad from polymatrix.utils.getstacklines import FrameSummary from polymatrix.utils.getstacklines import get_stack_lines -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.utils.formatsubstitutions import format_substitutions -from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.impl import FromTupleExprImpl, AdditionExprImpl @@ -53,26 +54,6 @@ def init_combinations_expr( ) -# def init_derivative_expr( -# underlying: ExpressionBaseMixin, -# variables: ExpressionBaseMixin, -# stack: tuple[FrameSummary], -# introduce_derivatives: bool = None, -# ): - -# assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' - -# if introduce_derivatives is None: -# introduce_derivatives = False - -# return polymatrix.expression.impl.DerivativeExprImpl( -# underlying=underlying, -# variables=variables, -# introduce_derivatives=introduce_derivatives, -# stack=stack, -# ) - - def init_determinant_expr( underlying: ExpressionBaseMixin, ): @@ -162,20 +143,6 @@ def init_eye_expr( ) -def init_filter_expr( - underlying: ExpressionBaseMixin, - predicator: ExpressionBaseMixin, - inverse: bool = None, -): - if inverse is None: - inverse = False - - return polymatrix.expression.impl.FilterExprImpl( - underlying=underlying, - predicator=predicator, - inverse=inverse, -) - def init_from_symmetric_matrix_expr( underlying: ExpressionBaseMixin, @@ -195,6 +162,9 @@ def init_from_expr_or_none( underlying=init_from_expr_or_none(1), name=data, ) + + elif isinstance(data, StateMonad): + return data.flat_map(lambda inner_data: init_from_expr_or_none(inner_data)) elif isinstance(data, np.ndarray): assert len(data.shape) <= 2 @@ -310,22 +280,6 @@ def init_half_newton_polytope_expr( ) -# def init_linear_in_expr( -# underlying: ExpressionBaseMixin, -# monomials: ExpressionBaseMixin, -# variables: ExpressionBaseMixin, -# ignore_unmatched: bool = None, -# ): -# assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' - -# return polymatrix.expression.impl.LinearInExprImpl( -# underlying=underlying, -# monomials=monomials, -# variables=variables, -# ignore_unmatched = ignore_unmatched, -# ) - - def init_linear_matrix_in_expr( underlying: ExpressionBaseMixin, variable: int, @@ -336,19 +290,6 @@ def init_linear_matrix_in_expr( ) -# def init_linear_monomials_expr( -# underlying: ExpressionBaseMixin, -# variables: ExpressionBaseMixin, -# ): - -# assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' - -# return polymatrix.expression.impl.LinearMonomialsExprImpl( -# underlying=underlying, -# variables=variables, -# ) - - def init_matrix_mult_expr( left: ExpressionBaseMixin, right: ExpressionBaseMixin, @@ -361,13 +302,6 @@ def init_matrix_mult_expr( ) -def init_max_degree_expr( - underlying: ExpressionBaseMixin, -): - return polymatrix.expression.impl.MaxDegreeExprImpl( - underlying=underlying, -) - def init_max_expr( underlying: ExpressionBaseMixin, @@ -569,9 +503,9 @@ def init_transpose_expr( def init_truncate_expr( underlying: ExpressionBaseMixin, - variables: ExpressionBaseMixin, degrees: tuple[int], - inverse: bool = None, + variables: ExpressionBaseMixin | None = None, + inverse: bool | None = None, ): if isinstance(degrees, int): degrees = (degrees,) @@ -585,20 +519,3 @@ def init_truncate_expr( degrees=degrees, inverse=inverse, ) - - -# def init_v_stack_expr( -# underlying: tuple, -# ): - -# def gen_underlying(): - -# for e in underlying: -# if isinstance(e, ExpressionBaseMixin): -# yield e -# else: -# yield init_from_(e) - -# return polymatrix.expression.impl.VStackExprImpl( -# underlying=tuple(gen_underlying()), -# ) diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index 2cbbe1e..7dbf1d2 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -1,12 +1,10 @@ import abc import math -import typing -import dataclassabc +from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl from polymatrix.utils.getstacklines import FrameSummary from polymatrix.utils.tooperatorexception import to_operator_exception from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -28,7 +26,32 @@ class AdditionExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - # overwrites abstract method of `ExpressionBaseMixin` + @staticmethod + def broadcast(left: PolyMatrix, right: PolyMatrix, stack: tuple[FrameSummary]): + # broadcast left + if left.shape == (1, 1) and right.shape != (1, 1): + left = BroadcastPolyMatrixImpl( + polynomial=left.get_poly(0, 0), + shape=right.shape, + ) + + # broadcast right + elif left.shape != (1, 1) and right.shape == (1, 1): + right = BroadcastPolyMatrixImpl( + polynomial=right.get_poly(0, 0), + shape=left.shape, + ) + + else: + if not (left.shape == right.shape): + raise AssertionError(to_operator_exception( + message=f'{left.shape} != {right.shape}', + stack=stack, + )) + + return left, right + + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, @@ -36,38 +59,37 @@ class AdditionExprMixin(ExpressionBaseMixin): state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - if left.shape == (1, 1): - left, right = right, left + # if left.shape == (1, 1): + # left, right = right, left - if left.shape != (1, 1) and right.shape == (1, 1): + # if left.shape != (1, 1) and right.shape == (1, 1): - @dataclassabc.dataclassabc(frozen=True) - class BroadCastedPolyMatrix(PolyMatrixMixin): - underlying_monomials: tuple[tuple[int], float] - shape: tuple[int, int] + # # @dataclassabc.dataclassabc(frozen=True) + # # class BroadCastedPolyMatrix(PolyMatrixMixin): + # # underlying_monomials: tuple[tuple[int], float] + # # shape: tuple[int, int] - def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: - return self.underlying_monomials + # # def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: + # # return self.underlying_monomials - polynomial = right.get_poly(0, 0) - if polynomial is not None: + # right = BroadcastPolyMatrixImpl( + # polynomial=right.get_poly(0, 0), + # shape=left.shape, + # ) - broadcasted_right = BroadCastedPolyMatrix( - underlying_monomials=polynomial, - shape=left.shape, - ) + # # all_underlying = (left, broadcasted_right) - all_underlying = (left, broadcasted_right) + # else: + # if not (left.shape == right.shape): + # raise AssertionError(to_operator_exception( + # message=f'{left.shape} != {right.shape}', + # stack=self.stack, + # )) - else: - if not (left.shape == right.shape): - raise AssertionError(to_operator_exception( - message=f'{left.shape} != {right.shape}', - stack=self.stack, - )) + # # all_underlying = (left, right) - all_underlying = (left, right) + left, right = self.broadcast(left, right, self.stack) terms = {} @@ -76,7 +98,7 @@ class AdditionExprMixin(ExpressionBaseMixin): terms_row_col = {} - for underlying in all_underlying: + for underlying in (left, right): polynomial = underlying.get_poly(row, col) if polynomial is None: @@ -90,6 +112,7 @@ class AdditionExprMixin(ExpressionBaseMixin): if monomial not in terms_row_col: terms_row_col[monomial] = value + else: terms_row_col[monomial] += value diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py index 3ed5766..4754d53 100644 --- a/polymatrix/expression/mixins/blockdiagexprmixin.py +++ b/polymatrix/expression/mixins/blockdiagexprmixin.py @@ -15,7 +15,7 @@ class BlockDiagExprMixin(ExpressionBaseMixin): def underlying(self) -> tuple[ExpressionBaseMixin, ...]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py index 937b0f3..e5aef11 100644 --- a/polymatrix/expression/mixins/cacheexprmixin.py +++ b/polymatrix/expression/mixins/cacheexprmixin.py @@ -17,7 +17,7 @@ class CacheExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py index 395f068..da669f4 100644 --- a/polymatrix/expression/mixins/combinationsexprmixin.py +++ b/polymatrix/expression/mixins/combinationsexprmixin.py @@ -1,6 +1,7 @@ import abc import itertools +from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -26,7 +27,7 @@ class CombinationsExprMixin(ExpressionBaseMixin): def degrees(self) -> tuple[int, ...]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, @@ -59,32 +60,28 @@ class CombinationsExprMixin(ExpressionBaseMixin): for row, indexing in enumerate(indices): - # print(indexing) - - if indexing is tuple(): + # x.combinations((0, 1, 2)) produces [1, x, x**2] + if len(indexing) == 0: terms[row, 0] = {tuple(): 1.0} continue - def acc_product(acc, v): - left_monomials = acc - row = v - - right_monomials = poly_matrix.get_poly(row, 0).keys() - # print(right_monomials) - - if left_monomials is (None,): - return right_monomials + def acc_product(left, row): + right = poly_matrix.get_poly(row, 0) - monomials = tuple(multiply_monomials(left_monomials, right_monomials)) - return monomials + if len(left) == 0: + return right + + result = {} + multiply_polynomial(left, right, result) + return result - *_, monomials = itertools.accumulate( + *_, polynomial = itertools.accumulate( indexing, acc_product, - initial=(None,), + initial={}, ) - terms[row, 0] = {m: 1.0 for m in monomials} + terms[row, 0] = polynomial poly_matrix = init_poly_matrix( terms=terms, diff --git a/polymatrix/expression/mixins/degreeexprmixin.py b/polymatrix/expression/mixins/degreeexprmixin.py new file mode 100644 index 0000000..273add2 --- /dev/null +++ b/polymatrix/expression/mixins/degreeexprmixin.py @@ -0,0 +1,52 @@ + +import abc + +from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.utils.getstacklines import FrameSummary + + +class DegreeExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + @property + @abc.abstractmethod + def stack(self) -> tuple[FrameSummary]: + ... + + # overwrites the abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + state, underlying = self.underlying.apply(state=state) + + terms = {} + + for row in range(underlying.shape[0]): + for col in range(underlying.shape[1]): + + underlying_terms = underlying.get_poly(row, col) + + if underlying_terms is None or len(underlying_terms) == 0: + continue + + def gen_degrees(): + for monomial, _ in underlying_terms.items(): + yield sum(count for _, count in monomial) + + # degrees = tuple(gen_degrees()) + + terms[row, col] = {tuple(): max(gen_degrees())} + + poly_matrix = init_poly_matrix( + terms=terms, + shape=underlying.shape, + ) + + return state, poly_matrix diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index bab3c91..1728a2d 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -41,7 +41,7 @@ class DerivativeExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py index ef18e91..7614669 100644 --- a/polymatrix/expression/mixins/determinantexprmixin.py +++ b/polymatrix/expression/mixins/determinantexprmixin.py @@ -16,12 +16,12 @@ class DeterminantExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # # overwrites abstract method of `ExpressionBaseMixin` + # # overwrites the abstract method of `ExpressionBaseMixin` # @property # def shape(self) -> tuple[int, int]: # return self.underlying.shape[0], 1 - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py index 1867f8a..046dd5c 100644 --- a/polymatrix/expression/mixins/diagexprmixin.py +++ b/polymatrix/expression/mixins/diagexprmixin.py @@ -19,7 +19,7 @@ class DiagExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py index 4dc61cd..29d1a0a 100644 --- a/polymatrix/expression/mixins/divergenceexprmixin.py +++ b/polymatrix/expression/mixins/divergenceexprmixin.py @@ -22,7 +22,7 @@ class DivergenceExprMixin(ExpressionBaseMixin): def variables(self) -> typing.Union[tuple, ExpressionBaseMixin]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py index 90919a5..0f2fada 100644 --- a/polymatrix/expression/mixins/divisionexprmixin.py +++ b/polymatrix/expression/mixins/divisionexprmixin.py @@ -27,7 +27,7 @@ class DivisionExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index 96b793c..e6e64b1 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -23,7 +23,7 @@ class ElemMultExprMixin(ExpressionBaseMixin): def right(self) -> ExpressionBaseMixin: ... - # # overwrites abstract method of `ExpressionBaseMixin` + # # overwrites the abstract method of `ExpressionBaseMixin` # @property # def shape(self) -> tuple[int, int]: # return self.left.shape @@ -96,7 +96,7 @@ class ElemMultExprMixin(ExpressionBaseMixin): return state, poly_matrix - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py index 30dc178..0c23c1e 100644 --- a/polymatrix/expression/mixins/evalexprmixin.py +++ b/polymatrix/expression/mixins/evalexprmixin.py @@ -21,7 +21,7 @@ class EvalExprMixin(ExpressionBaseMixin): def substitutions(self) -> tuple: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py index 4a294ac..e7cd928 100644 --- a/polymatrix/expression/mixins/eyeexprmixin.py +++ b/polymatrix/expression/mixins/eyeexprmixin.py @@ -15,7 +15,7 @@ class EyeExprMixin(ExpressionBaseMixin): def variable(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py index 1cf90d8..86b5c6a 100644 --- a/polymatrix/expression/mixins/filterexprmixin.py +++ b/polymatrix/expression/mixins/filterexprmixin.py @@ -24,7 +24,7 @@ class FilterExprMixin(ExpressionBaseMixin): def inverse(self) -> bool: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, @@ -57,6 +57,9 @@ class FilterExprMixin(ExpressionBaseMixin): if key in predicator_poly: predicator_value = round(predicator_poly[key]) + if isinstance(predicator_value, (float, bool)): + predicator_value = int(predicator_value) + else: predicator_value = 0 diff --git a/polymatrix/expression/mixins/filterlinearpartexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py index 8eaf2ef..dfca992 100644 --- a/polymatrix/expression/mixins/filterlinearpartexprmixin.py +++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py @@ -20,7 +20,7 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin): def variable(self) -> int: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py index a31655f..7d48c0f 100644 --- a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py +++ b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py @@ -18,7 +18,7 @@ class FromSymmetricMatrixExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py index 00f1b64..8aac10b 100644 --- a/polymatrix/expression/mixins/fromtermsexprmixin.py +++ b/polymatrix/expression/mixins/fromtermsexprmixin.py @@ -20,7 +20,7 @@ class FromTermsExprMixin(ExpressionBaseMixin): def shape(self) -> tuple[int, int]: pass - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/fromtupleexprmixin.py b/polymatrix/expression/mixins/fromtupleexprmixin.py index 302bcc6..19b3ed3 100644 --- a/polymatrix/expression/mixins/fromtupleexprmixin.py +++ b/polymatrix/expression/mixins/fromtupleexprmixin.py @@ -25,7 +25,7 @@ class FromTupleExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, @@ -36,6 +36,9 @@ class FromTupleExprMixin(ExpressionBaseMixin): for poly_row, col_data in enumerate(self.data): for poly_col, poly_data in enumerate(col_data): + if isinstance(poly_data, (bool, np.bool_)): + poly_data = int(poly_data) + if isinstance(poly_data, (int, float, np.number)): if math.isclose(poly_data, 0): polynomial = {} diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py index a3b513f..76a5ee4 100644 --- a/polymatrix/expression/mixins/getitemexprmixin.py +++ b/polymatrix/expression/mixins/getitemexprmixin.py @@ -27,7 +27,7 @@ class GetItemExprMixin(ExpressionBaseMixin): def index(self) -> tuple[tuple[int, ...], tuple[int, ...]]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py index c6a470f..21d16fd 100644 --- a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py +++ b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py @@ -28,7 +28,7 @@ class HalfNewtonPolytopeExprMixin(ExpressionBaseMixin): def filter(self) -> ExpressionBaseMixin | None: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/legendreseriesmixin.py b/polymatrix/expression/mixins/legendreseriesmixin.py index 98139ae..194e65e 100644 --- a/polymatrix/expression/mixins/legendreseriesmixin.py +++ b/polymatrix/expression/mixins/legendreseriesmixin.py @@ -23,7 +23,7 @@ class LegendreSeriesMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index 0bc99e9..ded61ec 100644 --- a/polymatrix/expression/mixins/linearinexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -51,7 +51,7 @@ class LinearInExprMixin(ExpressionBaseMixin): def ignore_unmatched(self) -> bool: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py index 8e0f997..4d25a1e 100644 --- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py +++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py @@ -21,7 +21,7 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin): def variables(self) -> tuple: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py index 1d077e2..516a7a4 100644 --- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py @@ -38,7 +38,7 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin): def variables(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py index 7014157..3343a65 100644 --- a/polymatrix/expression/mixins/matrixmultexprmixin.py +++ b/polymatrix/expression/mixins/matrixmultexprmixin.py @@ -26,7 +26,7 @@ class MatrixMultExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, @@ -69,4 +69,4 @@ class MatrixMultExprMixin(ExpressionBaseMixin): shape=(left.shape[0], right.shape[1]), ) - return state, poly_matrix + return state, poly_matrix diff --git a/polymatrix/expression/mixins/maxdegreeexprmixin.py b/polymatrix/expression/mixins/maxdegreeexprmixin.py deleted file mode 100644 index 0094b9b..0000000 --- a/polymatrix/expression/mixins/maxdegreeexprmixin.py +++ /dev/null @@ -1,43 +0,0 @@ - -import abc - -from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate.abc import ExpressionState - - -class MaxDegreeExprMixin(ExpressionBaseMixin): - @property - @abc.abstractmethod - def underlying(self) -> ExpressionBaseMixin: - ... - - # overwrites abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: - state, underlying = self.underlying.apply(state=state) - - terms = {} - - for row in range(underlying.shape[0]): - for col in range(underlying.shape[1]): - - underlying_terms = underlying.get_poly(row, col) - if underlying_terms is None: - continue - - def gen_degrees(): - for monomial, _ in underlying_terms.items(): - yield sum(count for _, count in monomial) - - terms[row, col] = {tuple(): max(gen_degrees())} - - poly_matrix = init_poly_matrix( - terms=terms, - shape=underlying.shape, - ) - - return state, poly_matrix diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py index 5366a83..bfabc4d 100644 --- a/polymatrix/expression/mixins/maxexprmixin.py +++ b/polymatrix/expression/mixins/maxexprmixin.py @@ -13,7 +13,7 @@ class MaxExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py index c16fcf4..712ff3f 100644 --- a/polymatrix/expression/mixins/parametrizeexprmixin.py +++ b/polymatrix/expression/mixins/parametrizeexprmixin.py @@ -19,7 +19,7 @@ class ParametrizeExprMixin(ExpressionBaseMixin): def name(self) -> str: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/parametrizematrixexprmixin.py b/polymatrix/expression/mixins/parametrizematrixexprmixin.py index a4cc43d..1aa393b 100644 --- a/polymatrix/expression/mixins/parametrizematrixexprmixin.py +++ b/polymatrix/expression/mixins/parametrizematrixexprmixin.py @@ -19,7 +19,7 @@ class ParametrizeMatrixExprMixin(ExpressionBaseMixin): def name(self) -> str: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/productexprmixin.py b/polymatrix/expression/mixins/productexprmixin.py index c60ffc2..45be74b 100644 --- a/polymatrix/expression/mixins/productexprmixin.py +++ b/polymatrix/expression/mixins/productexprmixin.py @@ -1,6 +1,7 @@ import abc import itertools +from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial from polymatrix.utils.getstacklines import FrameSummary from polymatrix.utils.tooperatorexception import to_operator_exception @@ -27,7 +28,7 @@ class ProductExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, @@ -101,29 +102,36 @@ class ProductExprMixin(ExpressionBaseMixin): for row, indexing in enumerate(indices): - # print(indexing) + # def acc_product(acc, v): + # left_monomials = acc + # polymatrix, row = v - def acc_product(acc, v): - left_monomials = acc - polymatrix, row = v + # right_monomials = polymatrix.get_poly(row, 0).keys() - right_monomials = polymatrix.get_poly(row, 0).keys() + # if left_monomials is (None,): + # return right_monomials - # print(f'{left_monomials=}') - # print(f'{right_monomials=}') + # return tuple(multiply_monomials(left_monomials, right_monomials)) - if left_monomials is (None,): - return right_monomials + def acc_product(left, v): + poly_matrix, row = v - return tuple(multiply_monomials(left_monomials, right_monomials)) + right = poly_matrix.get_poly(row, 0) - *_, monomials = itertools.accumulate( + if len(left) == 0: + return right + + result = {} + multiply_polynomial(left, right, result) + return result + + *_, polynomial = itertools.accumulate( zip(underlying, indexing), acc_product, - initial=(None,), + initial={}, ) - terms[row, 0] = {m: 1.0 for m in monomials} + terms[row, 0] = polynomial poly_matrix = init_poly_matrix( terms=terms, diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py index f567f5e..6fe1d4b 100644 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ b/polymatrix/expression/mixins/quadraticinexprmixin.py @@ -34,7 +34,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py index d53613a..ff16275 100644 --- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py @@ -38,7 +38,7 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin): def variables(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py index 4726e88..48748be 100644 --- a/polymatrix/expression/mixins/repmatexprmixin.py +++ b/polymatrix/expression/mixins/repmatexprmixin.py @@ -16,7 +16,7 @@ class RepMatExprMixin(ExpressionBaseMixin): def repetition(self) -> tuple[int, int]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py index 71042a3..915e750 100644 --- a/polymatrix/expression/mixins/reshapeexprmixin.py +++ b/polymatrix/expression/mixins/reshapeexprmixin.py @@ -20,7 +20,7 @@ class ReshapeExprMixin(ExpressionBaseMixin): def new_shape(self) -> tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py index 51428cb..9250d3c 100644 --- a/polymatrix/expression/mixins/setelementatexprmixin.py +++ b/polymatrix/expression/mixins/setelementatexprmixin.py @@ -33,7 +33,7 @@ class SetElementAtExprMixin(ExpressionBaseMixin): def value(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py index 28fec73..eed3d71 100644 --- a/polymatrix/expression/mixins/squeezeexprmixin.py +++ b/polymatrix/expression/mixins/squeezeexprmixin.py @@ -14,7 +14,7 @@ class SqueezeExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py index 64670e9..1b30dcb 100644 --- a/polymatrix/expression/mixins/substituteexprmixin.py +++ b/polymatrix/expression/mixins/substituteexprmixin.py @@ -23,7 +23,7 @@ class SubstituteExprMixin(ExpressionBaseMixin): def substitutions(self) -> tuple[tuple[typing.Any, ExpressionBaseMixin], ...]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py index 67ee6f0..37add0a 100644 --- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py @@ -22,7 +22,7 @@ class SubtractMonomialsExprMixin(ExpressionBaseMixin): def monomials(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py index 92be990..812b260 100644 --- a/polymatrix/expression/mixins/sumexprmixin.py +++ b/polymatrix/expression/mixins/sumexprmixin.py @@ -21,7 +21,7 @@ class SumExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py index bcdf316..769e8c6 100644 --- a/polymatrix/expression/mixins/symmetricexprmixin.py +++ b/polymatrix/expression/mixins/symmetricexprmixin.py @@ -22,7 +22,7 @@ class SymmetricExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py index 40693d3..b8b3d64 100644 --- a/polymatrix/expression/mixins/toconstantexprmixin.py +++ b/polymatrix/expression/mixins/toconstantexprmixin.py @@ -19,7 +19,7 @@ class ToConstantExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py index 6ba2d3a..d10fcd9 100644 --- a/polymatrix/expression/mixins/toquadraticexprmixin.py +++ b/polymatrix/expression/mixins/toquadraticexprmixin.py @@ -16,7 +16,7 @@ class ToQuadraticExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py index 9768bcc..a61a7e9 100644 --- a/polymatrix/expression/mixins/tosortedvariablesmixin.py +++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py @@ -15,7 +15,7 @@ class ToSortedVariablesExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py index 164f159..8801147 100644 --- a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py +++ b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py @@ -20,7 +20,7 @@ class ToSymmetricMatrixExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionStateMixin, diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py index e1a468d..ce435e9 100644 --- a/polymatrix/expression/mixins/transposeexprmixin.py +++ b/polymatrix/expression/mixins/transposeexprmixin.py @@ -22,7 +22,7 @@ class TransposeExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # overwrites abstract method of `PolyMatrixExprBaseMixin` + # overwrites the abstract method of `PolyMatrixExprBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py index 05afd55..fdf3921 100644 --- a/polymatrix/expression/mixins/truncateexprmixin.py +++ b/polymatrix/expression/mixins/truncateexprmixin.py @@ -16,7 +16,7 @@ class TruncateExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def variables(self) -> ExpressionBaseMixin: + def variables(self) -> ExpressionBaseMixin | None: ... @property @@ -29,13 +29,19 @@ class TruncateExprMixin(ExpressionBaseMixin): def inverse(self) -> bool: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - state, variable_indices = get_variable_indices_from_variable(state, self.variables) + + if self.variables is None: + cond = lambda idx: True + + else: + state, variable_indices = get_variable_indices_from_variable(state, self.variables) + cond = lambda idx: idx in variable_indices terms = {} @@ -50,7 +56,7 @@ class TruncateExprMixin(ExpressionBaseMixin): for monomial, value in polynomial.items(): - degree = sum((count for var_idx, count in monomial if var_idx in variable_indices)) + degree = sum((count for var_idx, count in monomial if cond(var_idx))) if (degree in self.degrees) is not self.inverse: terms_row_col[monomial] = value diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py index 104a608..98ab663 100644 --- a/polymatrix/expression/mixins/vstackexprmixin.py +++ b/polymatrix/expression/mixins/vstackexprmixin.py @@ -21,7 +21,7 @@ class VStackExprMixin(ExpressionBaseMixin): def underlying(self) -> tuple[ExpressionBaseMixin, ...]: ... - # overwrites abstract method of `ExpressionBaseMixin` + # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, diff --git a/polymatrix/expression/op.py b/polymatrix/expression/op.py index b62253c..f19d310 100644 --- a/polymatrix/expression/op.py +++ b/polymatrix/expression/op.py @@ -24,6 +24,22 @@ def diff( ) +def filter_( + underlying: ExpressionBaseMixin, + predicator: ExpressionBaseMixin, + inverse: bool = None, +) -> ExpressionBaseMixin: + + if inverse is None: + inverse = False + + return polymatrix.expression.impl.FilterExprImpl( + underlying=underlying, + predicator=predicator, + inverse=inverse, + ) + + def legendre( expression: ExpressionBaseMixin, degrees: tuple[int, ...] = None, @@ -62,6 +78,14 @@ def linear_monomials( ) -> ExpressionBaseMixin: return polymatrix.expression.impl.LinearMonomialsExprImpl( - underlying=expression, - variables=variables, - ) + underlying=expression, + variables=variables, + ) + +def degree( + underlying: ExpressionBaseMixin, +): + return polymatrix.expression.impl.DegreeExprImpl( + underlying=underlying, + stack=get_stack_lines(), + ) diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py index 5bec552..b02bb11 100644 --- a/polymatrix/expression/to.py +++ b/polymatrix/expression/to.py @@ -10,18 +10,18 @@ from polymatrix.statemonad.init import init_state_monad from polymatrix.statemonad.mixins import StateMonadMixin -def shape( - expr: Expression, -) -> StateMonadMixin[ExpressionState, tuple[int, ...]]: - def func(state: ExpressionState): - state, polymatrix = expr.apply(state) +# def shape( +# expr: Expression, +# ) -> StateMonadMixin[ExpressionState, tuple[int, ...]]: +# def func(state: ExpressionState): +# state, polymatrix = expr.apply(state) - return state, polymatrix.shape +# return state, polymatrix.shape - return init_state_monad(func) +# return init_state_monad(func) -def to_constant_repr( +def to_constant( expr: Expression, assert_constant: bool = True, ) -> StateMonadMixin[ExpressionState, np.ndarray]: @@ -44,40 +44,40 @@ def to_constant_repr( return init_state_monad(func) -def to_degrees( - expr: Expression, - variables: Expression, -) -> StateMonadMixin[ExpressionState, np.ndarray]: +# def to_degrees( +# expr: Expression, +# variables: Expression, +# ) -> StateMonadMixin[ExpressionState, np.ndarray]: - def func(state: ExpressionState): - state, underlying = expr.apply(state) - state, variable_indices = get_variable_indices_from_variable(state, variables) +# def func(state: ExpressionState): +# state, underlying = expr.apply(state) +# state, variable_indices = get_variable_indices_from_variable(state, variables) - def gen_rows(): - for row in range(underlying.shape[0]): - def gen_cols(): - for col in range(underlying.shape[1]): +# def gen_rows(): +# for row in range(underlying.shape[0]): +# def gen_cols(): +# for col in range(underlying.shape[1]): - def gen_degrees(): - polynomial = underlying.get_poly(row, col) +# def gen_degrees(): +# polynomial = underlying.get_poly(row, col) - if polynomial is None: - yield 0 +# if polynomial is None: +# yield 0 - else: - for monomial, _ in polynomial.items(): - yield sum(count for var, count in monomial if var in variable_indices) +# else: +# for monomial, _ in polynomial.items(): +# yield sum(count for var, count in monomial if var in variable_indices) - yield tuple(set(gen_degrees())) +# yield tuple(set(gen_degrees())) - yield tuple(gen_cols()) +# yield tuple(gen_cols()) - return state, tuple(gen_rows()) +# return state, tuple(gen_rows()) - return init_state_monad(func) +# return init_state_monad(func) -def to_sympy_repr( +def to_sympy( expr: Expression, ) -> StateMonadMixin[ExpressionState, sympy.Expr]: diff --git a/polymatrix/expression/utils/multiplypolynomial.py b/polymatrix/expression/utils/multiplypolynomial.py index 0b6f5b6..e27a124 100644 --- a/polymatrix/expression/utils/multiplypolynomial.py +++ b/polymatrix/expression/utils/multiplypolynomial.py @@ -2,9 +2,13 @@ import itertools import math from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices +from polymatrix.polymatrix.typing import PolynomialData - -def multiply_polynomial(left, right, terms): +def multiply_polynomial(left: PolynomialData, right: PolynomialData, result: PolynomialData): + """ + Multiplies two polynomials `left` and `right` and adds the result to the mutable polynomial `result`. + """ + for (left_monomial, left_value), (right_monomial, right_value) \ in itertools.product(left.items(), right.items()): @@ -15,10 +19,10 @@ def multiply_polynomial(left, right, terms): monomial = merge_monomial_indices((left_monomial, right_monomial)) - if monomial not in terms: - terms[monomial] = 0 + if monomial not in result: + result[monomial] = 0 - terms[monomial] += value + result[monomial] += value - if math.isclose(terms[monomial], 0, abs_tol=1e-12): - del terms[monomial] + if math.isclose(result[monomial], 0, abs_tol=1e-12): + del result[monomial] diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py index f44dc9c..fe5946e 100644 --- a/polymatrix/polymatrix/impl.py +++ b/polymatrix/polymatrix/impl.py @@ -1,9 +1,16 @@ import dataclassabc -from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin @dataclassabc.dataclassabc(frozen=True) class PolyMatrixImpl(PolyMatrix): terms: dict shape: tuple[int, ...] + + +@dataclassabc.dataclassabc(frozen=True) +class BroadcastPolyMatrixImpl(BroadcastPolyMatrixMixin): + polynomial: tuple[tuple[int], float] + shape: tuple[int, int] diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py index a097615..1aa0f19 100644 --- a/polymatrix/polymatrix/mixins.py +++ b/polymatrix/polymatrix/mixins.py @@ -18,7 +18,7 @@ class PolyMatrixMixin(abc.ABC): yield (row, col), polynomial @abc.abstractclassmethod - def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float] | None: ... @@ -31,7 +31,21 @@ class PolyMatrixAsDictMixin( def terms(self) -> dict[tuple[int, int], dict[tuple[int, ...], float]]: ... - # overwrites abstract method of `PolyMatrixMixin` - def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: + # overwrites the abstract method of `PolyMatrixMixin` + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float] | None: if (row, col) in self.terms: return self.terms[row, col] + + +class BroadcastPolyMatrixMixin( + PolyMatrixMixin, + abc.ABC, +): + @property + @abc.abstractmethod + def polynomial(self) -> dict[tuple[int, ...], float]: + ... + + # overwrites the abstract method of `PolyMatrixMixin` + def get_poly(self, col: int, row: int) -> dict[tuple[int, ...], float] | None: + return self.polynomial diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py new file mode 100644 index 0000000..f8031e5 --- /dev/null +++ b/polymatrix/polymatrix/typing.py @@ -0,0 +1,4 @@ + + +PolynomialData = dict[tuple[int, ...], float] +PolynomialMatrixData = dict[tuple[int, int], dict[tuple[int, ...], float]] \ No newline at end of file diff --git a/polymatrix/statemonad/mixins.py b/polymatrix/statemonad/mixins.py index 1a13440..dad7662 100644 --- a/polymatrix/statemonad/mixins.py +++ b/polymatrix/statemonad/mixins.py @@ -67,3 +67,6 @@ class StateMonadMixin( def apply(self, state: State) -> Tuple[State, U]: return self.apply_func(state) + def read(self, state: State) -> U: + return self.apply_func(state)[1] + -- cgit v1.2.1