from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from functools import wraps
from typing import Callable, Tuple, TypeVar, Generic, Iterable, Any


class StateCacheMixin(ABC):
    """Interface for states that carry a memoisation dict.

    ``StateMonad.cache`` reads ``state.cache`` and writes an updated copy back
    via ``dataclasses.replace``, so concrete states are expected to be
    dataclasses exposing a ``cache`` field.
    """

    @property
    @abstractmethod
    def cache(self) -> dict: ...


State = TypeVar("State", bound=StateCacheMixin)
U = TypeVar("U")
V = TypeVar("V")


@dataclass(frozen=True)
class StateMonad(Generic[State, U]):
    """A state transition ``State -> (State, U)`` wrapped for composition.

    Naming conventions:
      * ``map`` / ``flat_map`` mirror Scala's ``map`` / ``flatMap``
        (Haskell ``fmap`` / ``>>=``).
      * ``apply`` runs the transition and returns ``(new_state, value)``
        (Haskell ``runState``).
      * ``read`` runs the transition and returns only the value, discarding
        the new state (Haskell ``evalState``).

    Instances are frozen and hashable. Equality and hashing are derived from
    ``apply_func`` (plain functions compare by identity) and ``arguments``.
    ``cache`` uses the monad itself as the cache key, so ``arguments`` must be
    hashable for a cached monad.
    """

    apply_func: Callable[[State], tuple[State, U]]
    # Optional record of the arguments passed to ``apply_func``. It does not
    # affect evaluation. ``__str__`` shows it for debugging (e.g. to see which
    # expression ``to_sympy`` converted). It also participates in equality,
    # hashing and therefore cache keys.
    # TODO: review whether this metadata belongs on the monad.
    arguments: Any | None = None

    def __str__(self) -> str:
        """Render as ``name(arg1, arg2, ...)``, or ``name(...)`` when no arguments were recorded."""
        name = getattr(self.apply_func, "__name__", type(self.apply_func).__name__)
        if self.arguments is None:
            return f"{name}(...)"
        # Splat iterables into a comma list, but treat str/bytes as a single
        # argument rather than joining their characters.
        if isinstance(self.arguments, Iterable) and not isinstance(
            self.arguments, (str, bytes)
        ):
            args = ", ".join(map(str, self.arguments))
        else:
            args = str(self.arguments)
        return f"{name}({args})"

    def map(self, fn: Callable[[U], V]) -> StateMonad[State, V]:
        """Return a monad that runs this one and transforms its value with ``fn``.

        The state produced by this monad is passed through untouched. The
        result keeps ``arguments`` and takes ``fn``'s name (via
        ``functools.wraps``) for ``__str__``.
        """

        @wraps(fn)
        def internal_map(state: State) -> Tuple[State, V]:
            n_state, val = self.apply(state)
            return n_state, fn(val)

        return replace(self, apply_func=internal_map)

    def flat_map(self, fn: Callable[[U], StateMonad[State, V]]) -> StateMonad[State, V]:
        """Return a monad that runs this one, then the monad ``fn(value)`` on the new state.

        The result keeps ``arguments`` and takes ``fn``'s name (via
        ``functools.wraps``) for ``__str__``.
        """

        @wraps(fn)
        def internal_map(state: State) -> Tuple[State, V]:
            n_state, val = self.apply(state)
            return fn(val).apply(n_state)

        return replace(self, apply_func=internal_map)

    def zip(self, other: StateMonad[State, V]) -> StateMonad[State, tuple[U, V]]:
        """Run this monad, then ``other`` on the resulting state, and pair their values."""

        def internal_map(state: State) -> Tuple[State, tuple[U, V]]:
            state, val1 = self.apply(state)
            state, val2 = other.apply(state)
            return state, (val1, val2)

        return replace(self, apply_func=internal_map)

    def cache(self) -> StateMonad[State, U]:
        """Return a memoised version of this monad.

        The value is stored in ``state.cache`` under this (un-memoised) monad
        as key.

        * On a hit, the state is returned unchanged with the stored value.
        * On a miss, the monad runs. The resulting state is copied with
          ``dataclasses.replace`` and given a new cache dict that also holds
          the value; the previous dict is not mutated.

        The result keeps this monad's name for ``__str__``.

        Raises ``TypeError`` when this monad is unhashable (e.g. list
        ``arguments``) or when the state is not a dataclass.
        """

        @wraps(self.apply_func)
        def internal_map(state: State) -> Tuple[State, U]:
            if self in state.cache:
                return state, state.cache[self]
            state, val = self.apply(state)
            state = replace(state, cache=state.cache | {self: val})
            return state, val

        return replace(self, apply_func=internal_map)

    def apply(self, state: State) -> Tuple[State, U]:
        """Run the transition on ``state`` and return ``(new_state, value)``."""
        return self.apply_func(state)

    def read(self, state: State) -> U:
        """Run the transition on ``state`` and return only the value.

        The resulting state, including any cache entries added, is discarded.
        """
        return self.apply(state)[1]


def init_state_monad(
    apply_func: Callable, arguments: Any | None = None
) -> StateMonad:
    """Construct a ``StateMonad`` from a transition ``state -> (state, value)``.

    ``arguments`` is optional debug metadata shown by ``str()``; see the
    ``StateMonad.arguments`` field.
    """
    return StateMonad(
        apply_func=apply_func,
        arguments=arguments,
    )