diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/expression.py | 96 | ||||
-rw-r--r-- | polymatrix/expression/from_.py | 2 | ||||
-rw-r--r-- | polymatrix/statemonad/mixins.py | 45 |
3 files changed, 58 insertions, 85 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 52c3da6..abe37aa 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import dataclasses -import typing import numpy as np +from numpy.typing import NDArray from abc import ABC, abstractmethod from dataclassabc import dataclassabc from typing_extensions import override @@ -42,7 +44,7 @@ class Expression(ExpressionBaseMixin, ABC): def read(self, state: ExpressionState) -> PolyMatrix: return self.apply(state)[1] - def __add__(self, other: ExpressionBaseMixin) -> "Expression": + def __add__(self, other: ExpressionBaseMixin) -> Expression: return self._binary(polymatrix.expression.init.init_addition_expr, self, other) def __getattr__(self, name): @@ -57,7 +59,8 @@ class Expression(ExpressionBaseMixin, ABC): else: return attr - def __getitem__(self, key: tuple[int, int]): + def __getitem__(self, key: tuple[int, int]) -> Expression: + # FIXME: typing for key is incorrect, could be a slice return self.copy( underlying=polymatrix.expression.init.init_get_item_expr( underlying=self.underlying, @@ -65,25 +68,16 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def __matmul__( - self, other: ExpressionBaseMixin | np.ndarray - ) -> "Expression": + def __matmul__(self, other: ExpressionBaseMixin | np.ndarray) -> Expression: return self._binary( polymatrix.expression.init.init_matrix_mult_expr, self, other ) - def __mul__(self, other) -> "Expression": + def __mul__(self, other) -> Expression: return self._binary(polymatrix.expression.init.init_elem_mult_expr, self, other) - def __pow__(self, num): - curr = 1 - - # FIXME: this only works for positive integral powers, consider raising - # an error if the power is not a positive integer - for _ in range(num): - curr = curr * self - - return curr + def __pow__(self, exponent: Expression | int | float) -> Expression: + return self._binary(polymatrix.expression.init.init_power_expr, self, exponent) def __neg__(self): return self * (-1) @@ -132,7 +126,7 @@ class Expression(ExpressionBaseMixin, ABC): return right.copy(underlying=op(left, right.underlying, stack)) - def cache(self) -> "Expression": + def cache(self) -> Expression: return self.copy( underlying=polymatrix.expression.init.init_cache_expr( underlying=self.underlying, @@ -171,11 +165,9 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def diff( - self, - variables: "Expression", - introduce_derivatives: bool = None, - ) -> "Expression": + # FIXME: sometime variables is a tuple, sometimes an expression. make consistent. + # FIXME: this function is probably broken + def diff(self, variables: Expression, introduce_derivatives: bool | None = None) -> Expression: return self.copy( underlying=diff( expression=self, @@ -184,10 +176,7 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def divergence( - self, - variables: tuple, - ) -> "Expression": + def divergence(self, variables: tuple) -> Expression: return self.copy( underlying=polymatrix.expression.init.init_divergence_expr( underlying=self.underlying, @@ -195,11 +184,7 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def eval( - self, - variable: tuple, - value: tuple[float, ...] = None, - ) -> "Expression": + def eval(self, variable: tuple, value: tuple[float, ...] | None = None,) -> Expression: return self.copy( underlying=polymatrix.expression.init.init_eval_expr( underlying=self.underlying, @@ -211,9 +196,9 @@ class Expression(ExpressionBaseMixin, ABC): # also applies to monomials (and variables?) def filter( self, - predicator: "Expression", - inverse: bool = None, - ) -> "Expression": + predicator: Expression, + inverse: bool | None = None, + ) -> Expression: return self.copy( underlying=filter_( underlying=self.underlying, @@ -223,7 +208,7 @@ class Expression(ExpressionBaseMixin, ABC): ) # only applies to symmetric matrix - def from_symmetric_matrix(self) -> "Expression": + def from_symmetric_matrix(self) -> Expression: return self.copy( underlying=polymatrix.expression.init.init_from_symmetric_matrix_expr( underlying=self.underlying, @@ -231,11 +216,7 @@ class Expression(ExpressionBaseMixin, ABC): ) # only applies to monomials - def half_newton_polytope( - self, - variables: "Expression", - filter: "Expression | None" = None, - ) -> "Expression": + def half_newton_polytope(self, variables: Expression, filter: Expression | None = None,) -> Expression: return self.copy( underlying=polymatrix.expression.init.init_half_newton_polytope_expr( monomials=self.underlying, @@ -245,8 +226,11 @@ class Expression(ExpressionBaseMixin, ABC): ) def integrate( - self, variables: "Expression", from_: tuple[float, ...], to: tuple[float, ...] - ) -> "Expression": + self, + variables: Expression, + from_: tuple[float, ...], + to: tuple[float, ...] + ) -> Expression: return self.copy( underlying=integrate( expression=self, @@ -256,7 +240,7 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def linear_matrix_in(self, variable: "Expression") -> "Expression": + def linear_matrix_in(self, variable: Expression) -> Expression: return self.copy( underlying=polymatrix.expression.init.init_linear_matrix_in_expr( underlying=self.underlying, @@ -264,10 +248,7 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def linear_monomials( - self, - variables: "Expression", - ) -> "Expression": + def linear_monomials(self, variables: Expression) -> Expression: return self.copy( underlying=linear_monomials( expression=self.underlying, @@ -277,10 +258,10 @@ class Expression(ExpressionBaseMixin, ABC): def linear_in( self, - variables: "Expression", - monomials: "Expression" = None, - ignore_unmatched: bool = None, - ) -> "Expression": + variables: Expression, + monomials: Expression | None = None, + ignore_unmatched: bool | None = None, + ) -> Expression: return self.copy( underlying=linear_in( expression=self.underlying, @@ -290,10 +271,7 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def legendre( - self, - degrees: tuple[int, ...] = None, - ) -> "Expression": + def legendre(self, degrees: tuple[int, ...] | None = None) -> Expression: return self.copy( underlying=legendre( expression=self.underlying, @@ -301,14 +279,14 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def max(self) -> "Expression": + def max(self) -> Expression: return self.copy( underlying=polymatrix.expression.init.init_max_expr( underlying=self.underlying, ), ) - def parametrize(self, name: str = None) -> "Expression": + def parametrize(self, name: str | None = None) -> Expression: return self.copy( underlying=polymatrix.expression.init.init_parametrize_expr( underlying=self.underlying, @@ -316,9 +294,7 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def quadratic_in( - self, variables: "Expression", monomials: "Expression" = None - ) -> "Expression": + def quadratic_in(self, variables: Expression, monomials: Expression | None = None) -> Expression: if monomials is None: monomials = self.quadratic_monomials(variables) diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index 52f9d06..052ee7a 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -33,7 +33,7 @@ def from_( ) -def from_statemonad(monad: StateMonad): +def from_statemonad(monad: StateMonad) -> Expression: return init_expression(polymatrix.expression.init.init_from_statemonad(monad)) diff --git a/polymatrix/statemonad/mixins.py b/polymatrix/statemonad/mixins.py index 68d1db6..5703241 100644 --- a/polymatrix/statemonad/mixins.py +++ b/polymatrix/statemonad/mixins.py @@ -1,11 +1,14 @@ -import abc -import dataclasses +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import replace +from functools import wraps from typing import Callable, Tuple, TypeVar, Generic -class StateCacheMixin(abc.ABC): +class StateCacheMixin(ABC): @property - @abc.abstractmethod + @abstractmethod def cache(self) -> dict: ... @@ -34,18 +37,15 @@ V = TypeVar("V") # NP: and make a new function (M[U] -> M[V] -> M[W]) # # NP: TODO: text comparing the above to implementation below -class StateMonadMixin( - Generic[State, U], - abc.ABC, -): +class StateMonadMixin(Generic[State, U], ABC): @property - @abc.abstractmethod + @abstractmethod def apply_func(self) -> Callable[[State], tuple[State, U]]: # NP: TODO comment ... @property - @abc.abstractmethod + @abstractmethod def arguments(self) -> U | None: # arguments that were given to the function apply_func. # this field is optional @@ -56,50 +56,47 @@ class StateMonadMixin( # sympy. ... - # NP: typing, use from __future__ import annotations - def map(self, fn: Callable[[U], V]) -> 'StateMonadMixin[State, V]': - # NP: add functools.wrap(fn) decorator to copy docstrings etc. + def map(self, fn: Callable[[U], V]) -> StateMonadMixin[State, V]: + @wraps(fn) def internal_map(state: State) -> Tuple[State, U]: n_state, val = self.apply(state) return n_state, fn(val) - return dataclasses.replace(self, apply_func=internal_map) + return replace(self, apply_func=internal_map) - # NP: shouldn't typing be - # NP: flat_map(self, fn: Callable[[U], StateMonadMixin[State, V]]) -> StateMonadMixin[State, V] - def flat_map(self, fn: Callable[[U], 'StateMonadMixin']) -> 'StateMonadMixin[State, V]': - # NP: add functools.wrap(fn) decorator + def flat_map(self, fn: Callable[[U], StateMonadMixin[State, V]]) -> StateMonadMixin[State, V]: + @wraps(fn) def internal_map(state: State) -> Tuple[State, V]: n_state, val = self.apply(state) return fn(val).apply(n_state) - return dataclasses.replace(self, apply_func=internal_map) + return replace(self, apply_func=internal_map) # FIXME: typing - def zip(self, other: 'StateMonadMixin') -> 'StateMonadMixin': + def zip(self, other: StateMonadMixin) -> StateMonadMixin: def internal_map(state: State) -> Tuple[State, V]: state, val1 = self.apply(state) state, val2 = other.apply(state) return state, (val1, val2) - return dataclasses.replace(self, apply_func=internal_map) + return replace(self, apply_func=internal_map) # FIXME: typing - def cache(self) -> 'StateMonadMixin': + def cache(self) -> StateMonadMixin: def internal_map(state: State) -> Tuple[State, V]: if self in state.cache: return state, state.cache[self] state, val = self.apply(state) - state = dataclasses.replace( + state = replace( state, cache=state.cache | {self: val}, ) return state, val - return dataclasses.replace(self, apply_func=internal_map) + return replace(self, apply_func=internal_map) # NP: Need to find consistent naming and explain somewhere naming convention # NP: of monad operations (is this from scala conventions? I have never used scala) |