From 75e0780acdd537e43e1ec8c7228e68ead13515cf Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Sun, 12 May 2024 15:20:12 +0200 Subject: Collapse StateMonad into single class Same reason as previous commit f094e4d91b44fc1e8b5f11aac2dd8073ba024fc8 --- polymatrix/denserepr/from_.py | 5 +- polymatrix/expression/from_.py | 2 +- polymatrix/expression/impl.py | 2 +- polymatrix/expression/init.py | 2 +- polymatrix/expression/mixins/fromstatemonad.py | 9 +- polymatrix/expression/to.py | 9 +- polymatrix/expression/typing.py | 2 +- polymatrix/expressionstate.py | 2 +- polymatrix/statemonad.py | 102 +++++++++++++++++++++++ polymatrix/statemonad/__init__.py | 25 ------ polymatrix/statemonad/abc.py | 7 -- polymatrix/statemonad/impl.py | 20 ----- polymatrix/statemonad/init.py | 13 --- polymatrix/statemonad/mixins.py | 109 ------------------------- 14 files changed, 118 insertions(+), 191 deletions(-) create mode 100644 polymatrix/statemonad.py delete mode 100644 polymatrix/statemonad/__init__.py delete mode 100644 polymatrix/statemonad/abc.py delete mode 100644 polymatrix/statemonad/impl.py delete mode 100644 polymatrix/statemonad/init.py delete mode 100644 polymatrix/statemonad/mixins.py diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py index 52040a1..ce60a03 100644 --- a/polymatrix/denserepr/from_.py +++ b/polymatrix/denserepr/from_.py @@ -1,8 +1,7 @@ import itertools import numpy as np -from polymatrix.statemonad.init import init_state_monad -from polymatrix.statemonad.mixins import StateMonadMixin +from polymatrix.statemonad import StateMonad, init_state_monad from polymatrix.expressionstate import ExpressionState from polymatrix.expression.expression import Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -20,7 +19,7 @@ def from_polymatrix( expressions: Expression | tuple[Expression], variables: Expression = None, sorted: bool = None, -) -> StateMonadMixin[ +) -> StateMonad[ ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]] ]: """ diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index 052ee7a..88181a8 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -7,7 +7,7 @@ import polymatrix.expression.init from polymatrix.expression.expression import init_expression, Expression, init_variable_expression, VariableExpression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.statemonad.abc import StateMonad +from polymatrix.statemonad import StateMonad # NP: this function name makes no sense to me, diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 19fd4dd..6c0edc5 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -1,7 +1,7 @@ import numpy.typing import sympy from typing_extensions import override -from polymatrix.statemonad.abc import StateMonad +from polymatrix.statemonad import StateMonad import dataclassabc diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 7eab999..32de8c1 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -9,7 +9,7 @@ import polymatrix.expression.impl from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolynomialMatrixData -from polymatrix.statemonad.abc import StateMonad +from polymatrix.statemonad import StateMonad from polymatrix.utils.getstacklines import FrameSummary from polymatrix.utils.getstacklines import get_stack_lines from polymatrix.expression.utils.formatsubstitutions import format_substitutions diff --git a/polymatrix/expression/mixins/fromstatemonad.py b/polymatrix/expression/mixins/fromstatemonad.py index b7d9ea7..47b3a4b 100644 --- a/polymatrix/expression/mixins/fromstatemonad.py +++ b/polymatrix/expression/mixins/fromstatemonad.py @@ -5,7 +5,7 @@ from polymatrix.expression.expression import Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin -from polymatrix.statemonad.abc import StateMonad +from polymatrix.statemonad import StateMonad class FromStateMonadMixin(ExpressionBaseMixin): """ @@ -25,13 +25,14 @@ class FromStateMonadMixin(ExpressionBaseMixin): @override def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state, expr = self.monad.apply(state) - # Case when monad wraps function - # f: ExpressionState -> (State, Expression) + # Case when monad wraps functions + # f: ExpressionState -> (ExpressionState, Expression) + # f: ExpressionState -> (ExpressionState, Mixin) if isinstance(expr, Expression | ExpressionBaseMixin): return expr.apply(state) # Case when monad wraps function - # f: ExpressionState -> (State, PolyMatrix) + # f: ExpressionState -> (ExpressionState, PolyMatrix) elif isinstance(expr, PolyMatrixMixin): return state, expr diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py index 4509d1e..7f9bac7 100644 --- a/polymatrix/expression/to.py +++ b/polymatrix/expression/to.py @@ -4,13 +4,12 @@ import numpy as np from polymatrix.expression.expression import Expression from polymatrix.expressionstate import ExpressionState -from polymatrix.statemonad.init import init_state_monad -from polymatrix.statemonad.mixins import StateMonadMixin +from polymatrix.statemonad import StateMonad, init_state_monad def shape( expr: Expression, -) -> StateMonadMixin[ExpressionState, tuple[int, ...]]: +) -> StateMonad[ExpressionState, tuple[int, ...]]: def func(state: ExpressionState): state, polymatrix = expr.apply(state) @@ -22,7 +21,7 @@ def shape( def to_constant( expr: Expression, assert_constant: bool = True, -) -> StateMonadMixin[ExpressionState, np.ndarray]: +) -> StateMonad[ExpressionState, np.ndarray]: def func(state: ExpressionState): state, underlying = expr.apply(state) @@ -43,7 +42,7 @@ def to_constant( def to_sympy( expr: Expression, -) -> StateMonadMixin[ExpressionState, sympy.Expr | sympy.Matrix]: +) -> StateMonad[ExpressionState, sympy.Expr | sympy.Matrix]: def polymatrix_to_sympy(state: ExpressionState) -> tuple[ExpressionState, sympy.Expr | sympy.Matrix]: diff --git a/polymatrix/expression/typing.py b/polymatrix/expression/typing.py index 58c5ca2..af762a0 100644 --- a/polymatrix/expression/typing.py +++ b/polymatrix/expression/typing.py @@ -1,7 +1,7 @@ from __future__ import annotations from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.statemonad.abc import StateMonad +from polymatrix.statemonad import StateMonad import numpy.typing as npt import sympy diff --git a/polymatrix/expressionstate.py b/polymatrix/expressionstate.py index 931f282..7d0eb26 100644 --- a/polymatrix/expressionstate.py +++ b/polymatrix/expressionstate.py @@ -10,7 +10,7 @@ from polymatrix.variable.abc import Variable from polymatrix.utils.deprecation import deprecated from polymatrix.polymatrix.index import MonomialIndex, VariableIndex -from polymatrix.statemonad.mixins import StateCacheMixin +from polymatrix.statemonad import StateCacheMixin # TODO: move to typing submodule class IndexRange(NamedTuple): diff --git a/polymatrix/statemonad.py b/polymatrix/statemonad.py new file mode 100644 index 0000000..dc680de --- /dev/null +++ b/polymatrix/statemonad.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, replace +from functools import wraps +from typing import Callable, Tuple, TypeVar, Generic, Iterable, Any + + +class StateCacheMixin(ABC): + @property + @abstractmethod + def cache(self) -> dict: ... + + +State = TypeVar("State", bound=StateCacheMixin) +U = TypeVar("U") +V = TypeVar("V") + + +@dataclass +class StateMonad(Generic[State, U]): + apply_func: Callable[[State], tuple[State, U]] + + # TODO: review this. It was added because I want to be able to see what + # was passed to the statemonads that are applied to expressions. For + # example in to_sympy, you want so see what expression is converted to + # sympy. + # arguments that were given to the function apply_func. + # this field is optional + arguments: U | None + + def __str__(self): + if not self.arguments: + return f"{str(self.apply_func.__name__)}(...)" + + args = str(self.arguments) + if isinstance(self.arguments, Iterable): + args = ", ".join(map(str, self.arguments)) + + return f"{str(self.apply_func.__name__)}({args})" + + def map(self, fn: Callable[[U], V]) -> StateMonad[State, V]: + @wraps(fn) + def internal_map(state: State) -> Tuple[State, U]: + n_state, val = self.apply(state) + return n_state, fn(val) + + return replace(self, apply_func=internal_map) + + def flat_map(self, fn: Callable[[U], StateMonad[State, V]]) -> StateMonad[State, V]: + @wraps(fn) + def internal_map(state: State) -> Tuple[State, V]: + n_state, val = self.apply(state) + return fn(val).apply(n_state) + + return replace(self, apply_func=internal_map) + + # FIXME: typing + def zip(self, other: StateMonad) -> StateMonad: + def internal_map(state: State) -> Tuple[State, V]: + state, val1 = self.apply(state) + state, val2 = other.apply(state) + return state, (val1, val2) + + return replace(self, apply_func=internal_map) + + # FIXME: typing + def cache(self) -> StateMonad: + def internal_map(state: State) -> Tuple[State, V]: + if self in state.cache: + return state, state.cache[self] + + state, val = self.apply(state) + + state = replace( + state, + cache=state.cache | {self: val}, + ) + + return state, val + + return replace(self, apply_func=internal_map) + + # NP: Need to find consistent naming and explain somewhere naming convention + # NP: of monad operations (is this from scala conventions? I have never used scala) + def apply(self, state: State) -> Tuple[State, U]: + return self.apply_func(state) + + # NP: find better name or add explaination somewhere + # NP: (I know what it does but the name is very vague) + def read(self, state: State) -> U: + return self.apply_func(state)[1] + + +def init_state_monad( + apply_func: Callable, + arguments: Any | None = None +): + return StateMonad( + apply_func=apply_func, + arguments=arguments, + ) diff --git a/polymatrix/statemonad/__init__.py b/polymatrix/statemonad/__init__.py deleted file mode 100644 index 48e369e..0000000 --- a/polymatrix/statemonad/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from polymatrix.statemonad.init import init_state_monad -from polymatrix.statemonad.abc import StateMonad - -# NP: this is the unit operation for the monad, why not move it inside the -# NP: monad class? -def from_(val): - def func(state): - return state, val - - return init_state_monad(func) - - -# NP: duplicate-ish with StateMonadMixin.zip, this one is more generic -# NP: consider moving this into the mixin and deleting this function -def zip(monads: tuple[StateMonad]): - def zip_func(state): - values = tuple() - - for monad in monads: - state, val = monad.apply(state) - values += (val,) - - return state, values - - return init_state_monad(zip_func) diff --git a/polymatrix/statemonad/abc.py b/polymatrix/statemonad/abc.py deleted file mode 100644 index 671b5bd..0000000 --- a/polymatrix/statemonad/abc.py +++ /dev/null @@ -1,7 +0,0 @@ -import abc - -from polymatrix.statemonad.mixins import StateMonadMixin - - -class StateMonad(StateMonadMixin, abc.ABC): - pass diff --git a/polymatrix/statemonad/impl.py b/polymatrix/statemonad/impl.py deleted file mode 100644 index 817ce47..0000000 --- a/polymatrix/statemonad/impl.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Callable, Any, Iterable -import dataclassabc - -from polymatrix.statemonad.abc import StateMonad - - -@dataclassabc.dataclassabc(frozen=True) -class StateMonadImpl(StateMonad): - apply_func: Callable - arguments: Any | None = None - - def __str__(self): - if not self.arguments: - return f"{str(self.apply_func.__name__)}(...)" - - args = str(self.arguments) - if isinstance(self.arguments, Iterable): - args = ", ".join(map(str, self.arguments)) - - return f"{str(self.apply_func.__name__)}({args})" diff --git a/polymatrix/statemonad/init.py b/polymatrix/statemonad/init.py deleted file mode 100644 index 0abe6e4..0000000 --- a/polymatrix/statemonad/init.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Callable, Any - -from polymatrix.statemonad.impl import StateMonadImpl - - -def init_state_monad( - apply_func: Callable, - arguments: Any | None = None -): - return StateMonadImpl( - apply_func=apply_func, - arguments=arguments, - ) diff --git a/polymatrix/statemonad/mixins.py b/polymatrix/statemonad/mixins.py deleted file mode 100644 index 5703241..0000000 --- a/polymatrix/statemonad/mixins.py +++ /dev/null @@ -1,109 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import replace -from functools import wraps -from typing import Callable, Tuple, TypeVar, Generic - - -class StateCacheMixin(ABC): - @property - @abstractmethod - def cache(self) -> dict: ... - - -State = TypeVar("State", bound=StateCacheMixin) -U = TypeVar("U") -V = TypeVar("V") - - -# NP: Monadic type, use shorthand M for StateMonadMixin. -# NP: (will use haskell-like notation in this comment) -# -# NP: typical operations for monad (do you agree wit this? If not please explain your conventions) -# -# NP: - unit operation (aka return, bad name) take a value `u :: U` and make a `m :: M[U]` -# NP: you often call operation "from" -# -# NP: - map operation (aka lift) take a function (U -> V) and make a new function (M[U] -> M[V]) -# -# NP: - bind operation (aka flat map) take a function (U -> M[V]) -# NP: and make a new function (M[U] -> M[V]) -# NP you call this operation flat_map -# -# NP: - apply operation take a M[U -> V] and make (M[U] -> M[V]) -# -# NP: - zip operation take a function (U -> V -> W) -# NP: and make a new function (M[U] -> M[V] -> M[W]) -# -# NP: TODO: text comparing the above to implementation below -class StateMonadMixin(Generic[State, U], ABC): - @property - @abstractmethod - def apply_func(self) -> Callable[[State], tuple[State, U]]: - # NP: TODO comment - ... - - @property - @abstractmethod - def arguments(self) -> U | None: - # arguments that were given to the function apply_func. - # this field is optional - - # TODO: review this. It was added because I want to be able to see what - # was passed to the statemonads that are applied to expressions. For - # example in to_sympy, you want so see what expression is converted to - # sympy. - ... - - def map(self, fn: Callable[[U], V]) -> StateMonadMixin[State, V]: - @wraps(fn) - def internal_map(state: State) -> Tuple[State, U]: - n_state, val = self.apply(state) - return n_state, fn(val) - - return replace(self, apply_func=internal_map) - - def flat_map(self, fn: Callable[[U], StateMonadMixin[State, V]]) -> StateMonadMixin[State, V]: - @wraps(fn) - def internal_map(state: State) -> Tuple[State, V]: - n_state, val = self.apply(state) - return fn(val).apply(n_state) - - return replace(self, apply_func=internal_map) - - # FIXME: typing - def zip(self, other: StateMonadMixin) -> StateMonadMixin: - def internal_map(state: State) -> Tuple[State, V]: - state, val1 = self.apply(state) - state, val2 = other.apply(state) - return state, (val1, val2) - - return replace(self, apply_func=internal_map) - - # FIXME: typing - def cache(self) -> StateMonadMixin: - def internal_map(state: State) -> Tuple[State, V]: - if self in state.cache: - return state, state.cache[self] - - state, val = self.apply(state) - - state = replace( - state, - cache=state.cache | {self: val}, - ) - - return state, val - - return replace(self, apply_func=internal_map) - - # NP: Need to find consistent naming and explain somewhere naming convention - # NP: of monad operations (is this from scala conventions? I have never used scala) - def apply(self, state: State) -> Tuple[State, U]: - return self.apply_func(state) - - # NP: find better name or add explaination somewhere - # NP: (I know what it does but the name is very vague) - def read(self, state: State) -> U: - return self.apply_func(state)[1] -- cgit v1.2.1