diff options
author | Nao Pross <np@0hm.ch> | 2024-05-12 15:20:12 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-12 15:22:49 +0200 |
commit | 75e0780acdd537e43e1ec8c7228e68ead13515cf (patch) | |
tree | 9632fadf6205f390ad4cf532141c99c5f0d85317 | |
parent | Collapse ExpressionState to a single class (diff) | |
download | polymatrix-75e0780acdd537e43e1ec8c7228e68ead13515cf.tar.gz polymatrix-75e0780acdd537e43e1ec8c7228e68ead13515cf.zip |
Collapse StateMonad into single class
Same reason as previous commit f094e4d91b44fc1e8b5f11aac2dd8073ba024fc8
-rw-r--r-- | polymatrix/denserepr/from_.py | 5 | ||||
-rw-r--r-- | polymatrix/expression/from_.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/mixins/fromstatemonad.py | 9 | ||||
-rw-r--r-- | polymatrix/expression/to.py | 9 | ||||
-rw-r--r-- | polymatrix/expression/typing.py | 2 | ||||
-rw-r--r-- | polymatrix/expressionstate.py | 2 | ||||
-rw-r--r-- | polymatrix/statemonad.py (renamed from polymatrix/statemonad/mixins.py) | 77 | ||||
-rw-r--r-- | polymatrix/statemonad/__init__.py | 25 | ||||
-rw-r--r-- | polymatrix/statemonad/abc.py | 7 | ||||
-rw-r--r-- | polymatrix/statemonad/impl.py | 20 | ||||
-rw-r--r-- | polymatrix/statemonad/init.py | 13 |
13 files changed, 51 insertions, 124 deletions
diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py index 52040a1..ce60a03 100644 --- a/polymatrix/denserepr/from_.py +++ b/polymatrix/denserepr/from_.py @@ -1,8 +1,7 @@ import itertools import numpy as np -from polymatrix.statemonad.init import init_state_monad -from polymatrix.statemonad.mixins import StateMonadMixin +from polymatrix.statemonad import StateMonad, init_state_monad from polymatrix.expressionstate import ExpressionState from polymatrix.expression.expression import Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -20,7 +19,7 @@ def from_polymatrix( expressions: Expression | tuple[Expression], variables: Expression = None, sorted: bool = None, -) -> StateMonadMixin[ +) -> StateMonad[ ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]] ]: """ diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index 052ee7a..88181a8 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -7,7 +7,7 @@ import polymatrix.expression.init from polymatrix.expression.expression import init_expression, Expression, init_variable_expression, VariableExpression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.statemonad.abc import StateMonad +from polymatrix.statemonad import StateMonad # NP: this function name makes no sense to me, diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 19fd4dd..6c0edc5 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -1,7 +1,7 @@ import numpy.typing import sympy from typing_extensions import override -from polymatrix.statemonad.abc import StateMonad +from polymatrix.statemonad import StateMonad import dataclassabc diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 7eab999..32de8c1 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -9,7 +9,7 @@ import polymatrix.expression.impl from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.index import PolynomialMatrixData -from polymatrix.statemonad.abc import StateMonad +from polymatrix.statemonad import StateMonad from polymatrix.utils.getstacklines import FrameSummary from polymatrix.utils.getstacklines import get_stack_lines from polymatrix.expression.utils.formatsubstitutions import format_substitutions diff --git a/polymatrix/expression/mixins/fromstatemonad.py b/polymatrix/expression/mixins/fromstatemonad.py index b7d9ea7..47b3a4b 100644 --- a/polymatrix/expression/mixins/fromstatemonad.py +++ b/polymatrix/expression/mixins/fromstatemonad.py @@ -5,7 +5,7 @@ from polymatrix.expression.expression import Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.mixins import PolyMatrixMixin -from polymatrix.statemonad.abc import StateMonad +from polymatrix.statemonad import StateMonad class FromStateMonadMixin(ExpressionBaseMixin): """ @@ -25,13 +25,14 @@ class FromStateMonadMixin(ExpressionBaseMixin): @override def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: state, expr = self.monad.apply(state) - # Case when monad wraps function - # f: ExpressionState -> (State, Expression) + # Case when monad wraps functions + # f: ExpressionState -> (ExpressionState, Expression) + # f: ExpressionState -> (ExpressionState, Mixin) if isinstance(expr, Expression | ExpressionBaseMixin): return expr.apply(state) # Case when monad wraps function - # f: ExpressionState -> (State, PolyMatrix) + # f: ExpressionState -> (ExpressionState, PolyMatrix) elif isinstance(expr, PolyMatrixMixin): return state, expr diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py index 4509d1e..7f9bac7 100644 --- a/polymatrix/expression/to.py +++ b/polymatrix/expression/to.py @@ -4,13 +4,12 @@ import numpy as np from polymatrix.expression.expression import Expression from polymatrix.expressionstate import ExpressionState -from polymatrix.statemonad.init import init_state_monad -from polymatrix.statemonad.mixins import StateMonadMixin +from polymatrix.statemonad import StateMonad, init_state_monad def shape( expr: Expression, -) -> StateMonadMixin[ExpressionState, tuple[int, ...]]: +) -> StateMonad[ExpressionState, tuple[int, ...]]: def func(state: ExpressionState): state, polymatrix = expr.apply(state) @@ -22,7 +21,7 @@ def shape( def to_constant( expr: Expression, assert_constant: bool = True, -) -> StateMonadMixin[ExpressionState, np.ndarray]: +) -> StateMonad[ExpressionState, np.ndarray]: def func(state: ExpressionState): state, underlying = expr.apply(state) @@ -43,7 +42,7 @@ def to_constant( def to_sympy( expr: Expression, -) -> StateMonadMixin[ExpressionState, sympy.Expr | sympy.Matrix]: +) -> StateMonad[ExpressionState, sympy.Expr | sympy.Matrix]: def polymatrix_to_sympy(state: ExpressionState) -> tuple[ExpressionState, sympy.Expr | sympy.Matrix]: diff --git a/polymatrix/expression/typing.py b/polymatrix/expression/typing.py index 58c5ca2..af762a0 100644 --- a/polymatrix/expression/typing.py +++ b/polymatrix/expression/typing.py @@ -1,7 +1,7 @@ from __future__ import annotations from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.statemonad.abc import StateMonad +from polymatrix.statemonad import StateMonad import numpy.typing as npt import sympy diff --git a/polymatrix/expressionstate.py b/polymatrix/expressionstate.py index 931f282..7d0eb26 100644 --- a/polymatrix/expressionstate.py +++ b/polymatrix/expressionstate.py @@ -10,7 +10,7 @@ from polymatrix.variable.abc import Variable from polymatrix.utils.deprecation import deprecated from polymatrix.polymatrix.index import MonomialIndex, VariableIndex -from polymatrix.statemonad.mixins import StateCacheMixin +from polymatrix.statemonad import StateCacheMixin # TODO: move to typing submodule class IndexRange(NamedTuple): diff --git a/polymatrix/statemonad/mixins.py b/polymatrix/statemonad.py index 5703241..dc680de 100644 --- a/polymatrix/statemonad/mixins.py +++ b/polymatrix/statemonad.py @@ -1,9 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import replace +from dataclasses import dataclass, replace from functools import wraps -from typing import Callable, Tuple, TypeVar, Generic +from typing import Callable, Tuple, TypeVar, Generic, Iterable, Any class StateCacheMixin(ABC): @@ -17,46 +17,29 @@ U = TypeVar("U") V = TypeVar("V") -# NP: Monadic type, use shorthand M for StateMonadMixin. -# NP: (will use haskell-like notation in this comment) -# -# NP: typical operations for monad (do you agree wit this? If not please explain your conventions) -# -# NP: - unit operation (aka return, bad name) take a value `u :: U` and make a `m :: M[U]` -# NP: you often call operation "from" -# -# NP: - map operation (aka lift) take a function (U -> V) and make a new function (M[U] -> M[V]) -# -# NP: - bind operation (aka flat map) take a function (U -> M[V]) -# NP: and make a new function (M[U] -> M[V]) -# NP you call this operation flat_map -# -# NP: - apply operation take a M[U -> V] and make (M[U] -> M[V]) -# -# NP: - zip operation take a function (U -> V -> W) -# NP: and make a new function (M[U] -> M[V] -> M[W]) -# -# NP: TODO: text comparing the above to implementation below -class StateMonadMixin(Generic[State, U], ABC): - @property - @abstractmethod - def apply_func(self) -> Callable[[State], tuple[State, U]]: - # NP: TODO comment - ... +@dataclass +class StateMonad(Generic[State, U]): + apply_func: Callable[[State], tuple[State, U]] - @property - @abstractmethod - def arguments(self) -> U | None: - # arguments that were given to the function apply_func. - # this field is optional + # TODO: review this. It was added because I want to be able to see what + # was passed to the statemonads that are applied to expressions. For + # example in to_sympy, you want so see what expression is converted to + # sympy. + # arguments that were given to the function apply_func. + # this field is optional + arguments: U | None - # TODO: review this. It was added because I want to be able to see what - # was passed to the statemonads that are applied to expressions. For - # example in to_sympy, you want so see what expression is converted to - # sympy. - ... + def __str__(self): + if not self.arguments: + return f"{str(self.apply_func.__name__)}(...)" - def map(self, fn: Callable[[U], V]) -> StateMonadMixin[State, V]: + args = str(self.arguments) + if isinstance(self.arguments, Iterable): + args = ", ".join(map(str, self.arguments)) + + return f"{str(self.apply_func.__name__)}({args})" + + def map(self, fn: Callable[[U], V]) -> StateMonad[State, V]: @wraps(fn) def internal_map(state: State) -> Tuple[State, U]: n_state, val = self.apply(state) @@ -64,7 +47,7 @@ class StateMonadMixin(Generic[State, U], ABC): return replace(self, apply_func=internal_map) - def flat_map(self, fn: Callable[[U], StateMonadMixin[State, V]]) -> StateMonadMixin[State, V]: + def flat_map(self, fn: Callable[[U], StateMonad[State, V]]) -> StateMonad[State, V]: @wraps(fn) def internal_map(state: State) -> Tuple[State, V]: n_state, val = self.apply(state) @@ -73,7 +56,7 @@ class StateMonadMixin(Generic[State, U], ABC): return replace(self, apply_func=internal_map) # FIXME: typing - def zip(self, other: StateMonadMixin) -> StateMonadMixin: + def zip(self, other: StateMonad) -> StateMonad: def internal_map(state: State) -> Tuple[State, V]: state, val1 = self.apply(state) state, val2 = other.apply(state) @@ -82,7 +65,7 @@ class StateMonadMixin(Generic[State, U], ABC): return replace(self, apply_func=internal_map) # FIXME: typing - def cache(self) -> StateMonadMixin: + def cache(self) -> StateMonad: def internal_map(state: State) -> Tuple[State, V]: if self in state.cache: return state, state.cache[self] @@ -107,3 +90,13 @@ class StateMonadMixin(Generic[State, U], ABC): # NP: (I know what it does but the name is very vague) def read(self, state: State) -> U: return self.apply_func(state)[1] + + +def init_state_monad( + apply_func: Callable, + arguments: Any | None = None +): + return StateMonad( + apply_func=apply_func, + arguments=arguments, + ) diff --git a/polymatrix/statemonad/__init__.py b/polymatrix/statemonad/__init__.py deleted file mode 100644 index 48e369e..0000000 --- a/polymatrix/statemonad/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from polymatrix.statemonad.init import init_state_monad -from polymatrix.statemonad.abc import StateMonad - -# NP: this is the unit operation for the monad, why not move it inside the -# NP: monad class? -def from_(val): - def func(state): - return state, val - - return init_state_monad(func) - - -# NP: duplicate-ish with StateMonadMixin.zip, this one is more generic -# NP: consider moving this into the mixin and deleting this function -def zip(monads: tuple[StateMonad]): - def zip_func(state): - values = tuple() - - for monad in monads: - state, val = monad.apply(state) - values += (val,) - - return state, values - - return init_state_monad(zip_func) diff --git a/polymatrix/statemonad/abc.py b/polymatrix/statemonad/abc.py deleted file mode 100644 index 671b5bd..0000000 --- a/polymatrix/statemonad/abc.py +++ /dev/null @@ -1,7 +0,0 @@ -import abc - -from polymatrix.statemonad.mixins import StateMonadMixin - - -class StateMonad(StateMonadMixin, abc.ABC): - pass diff --git a/polymatrix/statemonad/impl.py b/polymatrix/statemonad/impl.py deleted file mode 100644 index 817ce47..0000000 --- a/polymatrix/statemonad/impl.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Callable, Any, Iterable -import dataclassabc - -from polymatrix.statemonad.abc import StateMonad - - -@dataclassabc.dataclassabc(frozen=True) -class StateMonadImpl(StateMonad): - apply_func: Callable - arguments: Any | None = None - - def __str__(self): - if not self.arguments: - return f"{str(self.apply_func.__name__)}(...)" - - args = str(self.arguments) - if isinstance(self.arguments, Iterable): - args = ", ".join(map(str, self.arguments)) - - return f"{str(self.apply_func.__name__)}({args})" diff --git a/polymatrix/statemonad/init.py b/polymatrix/statemonad/init.py deleted file mode 100644 index 0abe6e4..0000000 --- a/polymatrix/statemonad/init.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Callable, Any - -from polymatrix.statemonad.impl import StateMonadImpl - - -def init_state_monad( - apply_func: Callable, - arguments: Any | None = None -): - return StateMonadImpl( - apply_func=apply_func, - arguments=arguments, - ) |