summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-12 15:20:12 +0200
committerNao Pross <np@0hm.ch>2024-05-12 15:22:49 +0200
commit75e0780acdd537e43e1ec8c7228e68ead13515cf (patch)
tree9632fadf6205f390ad4cf532141c99c5f0d85317
parentCollapse ExpressionState to a single class (diff)
downloadpolymatrix-75e0780acdd537e43e1ec8c7228e68ead13515cf.tar.gz
polymatrix-75e0780acdd537e43e1ec8c7228e68ead13515cf.zip
Collapse StateMonad into single class
Same reason as previous commit f094e4d91b44fc1e8b5f11aac2dd8073ba024fc8
Diffstat (limited to '')
-rw-r--r--polymatrix/denserepr/from_.py5
-rw-r--r--polymatrix/expression/from_.py2
-rw-r--r--polymatrix/expression/impl.py2
-rw-r--r--polymatrix/expression/init.py2
-rw-r--r--polymatrix/expression/mixins/fromstatemonad.py9
-rw-r--r--polymatrix/expression/to.py9
-rw-r--r--polymatrix/expression/typing.py2
-rw-r--r--polymatrix/expressionstate.py2
-rw-r--r--polymatrix/statemonad.py (renamed from polymatrix/statemonad/mixins.py)77
-rw-r--r--polymatrix/statemonad/__init__.py25
-rw-r--r--polymatrix/statemonad/abc.py7
-rw-r--r--polymatrix/statemonad/impl.py20
-rw-r--r--polymatrix/statemonad/init.py13
13 files changed, 51 insertions, 124 deletions
diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py
index 52040a1..ce60a03 100644
--- a/polymatrix/denserepr/from_.py
+++ b/polymatrix/denserepr/from_.py
@@ -1,8 +1,7 @@
import itertools
import numpy as np
-from polymatrix.statemonad.init import init_state_monad
-from polymatrix.statemonad.mixins import StateMonadMixin
+from polymatrix.statemonad import StateMonad, init_state_monad
from polymatrix.expressionstate import ExpressionState
from polymatrix.expression.expression import Expression
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -20,7 +19,7 @@ def from_polymatrix(
expressions: Expression | tuple[Expression],
variables: Expression = None,
sorted: bool = None,
-) -> StateMonadMixin[
+) -> StateMonad[
ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]
]:
"""
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index 052ee7a..88181a8 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -7,7 +7,7 @@ import polymatrix.expression.init
from polymatrix.expression.expression import init_expression, Expression, init_variable_expression, VariableExpression
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.statemonad.abc import StateMonad
+from polymatrix.statemonad import StateMonad
# NP: this function name makes no sense to me,
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 19fd4dd..6c0edc5 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -1,7 +1,7 @@
import numpy.typing
import sympy
from typing_extensions import override
-from polymatrix.statemonad.abc import StateMonad
+from polymatrix.statemonad import StateMonad
import dataclassabc
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 7eab999..32de8c1 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -9,7 +9,7 @@ import polymatrix.expression.impl
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.polymatrix.index import PolynomialMatrixData
-from polymatrix.statemonad.abc import StateMonad
+from polymatrix.statemonad import StateMonad
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.utils.getstacklines import get_stack_lines
from polymatrix.expression.utils.formatsubstitutions import format_substitutions
diff --git a/polymatrix/expression/mixins/fromstatemonad.py b/polymatrix/expression/mixins/fromstatemonad.py
index b7d9ea7..47b3a4b 100644
--- a/polymatrix/expression/mixins/fromstatemonad.py
+++ b/polymatrix/expression/mixins/fromstatemonad.py
@@ -5,7 +5,7 @@ from polymatrix.expression.expression import Expression
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate import ExpressionState
from polymatrix.polymatrix.mixins import PolyMatrixMixin
-from polymatrix.statemonad.abc import StateMonad
+from polymatrix.statemonad import StateMonad
class FromStateMonadMixin(ExpressionBaseMixin):
"""
@@ -25,13 +25,14 @@ class FromStateMonadMixin(ExpressionBaseMixin):
@override
def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
state, expr = self.monad.apply(state)
- # Case when monad wraps function
- # f: ExpressionState -> (State, Expression)
+ # Case when monad wraps functions
+ # f: ExpressionState -> (ExpressionState, Expression)
+ # f: ExpressionState -> (ExpressionState, Mixin)
if isinstance(expr, Expression | ExpressionBaseMixin):
return expr.apply(state)
# Case when monad wraps function
- # f: ExpressionState -> (State, PolyMatrix)
+ # f: ExpressionState -> (ExpressionState, PolyMatrix)
elif isinstance(expr, PolyMatrixMixin):
return state, expr
diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py
index 4509d1e..7f9bac7 100644
--- a/polymatrix/expression/to.py
+++ b/polymatrix/expression/to.py
@@ -4,13 +4,12 @@ import numpy as np
from polymatrix.expression.expression import Expression
from polymatrix.expressionstate import ExpressionState
-from polymatrix.statemonad.init import init_state_monad
-from polymatrix.statemonad.mixins import StateMonadMixin
+from polymatrix.statemonad import StateMonad, init_state_monad
def shape(
expr: Expression,
-) -> StateMonadMixin[ExpressionState, tuple[int, ...]]:
+) -> StateMonad[ExpressionState, tuple[int, ...]]:
def func(state: ExpressionState):
state, polymatrix = expr.apply(state)
@@ -22,7 +21,7 @@ def shape(
def to_constant(
expr: Expression,
assert_constant: bool = True,
-) -> StateMonadMixin[ExpressionState, np.ndarray]:
+) -> StateMonad[ExpressionState, np.ndarray]:
def func(state: ExpressionState):
state, underlying = expr.apply(state)
@@ -43,7 +42,7 @@ def to_constant(
def to_sympy(
expr: Expression,
-) -> StateMonadMixin[ExpressionState, sympy.Expr | sympy.Matrix]:
+) -> StateMonad[ExpressionState, sympy.Expr | sympy.Matrix]:
def polymatrix_to_sympy(state: ExpressionState) -> tuple[ExpressionState, sympy.Expr | sympy.Matrix]:
diff --git a/polymatrix/expression/typing.py b/polymatrix/expression/typing.py
index 58c5ca2..af762a0 100644
--- a/polymatrix/expression/typing.py
+++ b/polymatrix/expression/typing.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.statemonad.abc import StateMonad
+from polymatrix.statemonad import StateMonad
import numpy.typing as npt
import sympy
diff --git a/polymatrix/expressionstate.py b/polymatrix/expressionstate.py
index 931f282..7d0eb26 100644
--- a/polymatrix/expressionstate.py
+++ b/polymatrix/expressionstate.py
@@ -10,7 +10,7 @@ from polymatrix.variable.abc import Variable
from polymatrix.utils.deprecation import deprecated
from polymatrix.polymatrix.index import MonomialIndex, VariableIndex
-from polymatrix.statemonad.mixins import StateCacheMixin
+from polymatrix.statemonad import StateCacheMixin
# TODO: move to typing submodule
class IndexRange(NamedTuple):
diff --git a/polymatrix/statemonad/mixins.py b/polymatrix/statemonad.py
index 5703241..dc680de 100644
--- a/polymatrix/statemonad/mixins.py
+++ b/polymatrix/statemonad.py
@@ -1,9 +1,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from dataclasses import replace
+from dataclasses import dataclass, replace
from functools import wraps
-from typing import Callable, Tuple, TypeVar, Generic
+from typing import Callable, Tuple, TypeVar, Generic, Iterable, Any
class StateCacheMixin(ABC):
@@ -17,46 +17,29 @@ U = TypeVar("U")
V = TypeVar("V")
-# NP: Monadic type, use shorthand M for StateMonadMixin.
-# NP: (will use haskell-like notation in this comment)
-#
-# NP: typical operations for monad (do you agree wit this? If not please explain your conventions)
-#
-# NP: - unit operation (aka return, bad name) take a value `u :: U` and make a `m :: M[U]`
-# NP: you often call operation "from"
-#
-# NP: - map operation (aka lift) take a function (U -> V) and make a new function (M[U] -> M[V])
-#
-# NP: - bind operation (aka flat map) take a function (U -> M[V])
-# NP: and make a new function (M[U] -> M[V])
-# NP you call this operation flat_map
-#
-# NP: - apply operation take a M[U -> V] and make (M[U] -> M[V])
-#
-# NP: - zip operation take a function (U -> V -> W)
-# NP: and make a new function (M[U] -> M[V] -> M[W])
-#
-# NP: TODO: text comparing the above to implementation below
-class StateMonadMixin(Generic[State, U], ABC):
- @property
- @abstractmethod
- def apply_func(self) -> Callable[[State], tuple[State, U]]:
- # NP: TODO comment
- ...
+@dataclass
+class StateMonad(Generic[State, U]):
+ apply_func: Callable[[State], tuple[State, U]]
- @property
- @abstractmethod
- def arguments(self) -> U | None:
- # arguments that were given to the function apply_func.
- # this field is optional
+ # TODO: review this. It was added because I want to be able to see what
+ # was passed to the statemonads that are applied to expressions. For
+ # example in to_sympy, you want so see what expression is converted to
+ # sympy.
+ # arguments that were given to the function apply_func.
+ # this field is optional
+ arguments: U | None
- # TODO: review this. It was added because I want to be able to see what
- # was passed to the statemonads that are applied to expressions. For
- # example in to_sympy, you want so see what expression is converted to
- # sympy.
- ...
+ def __str__(self):
+ if not self.arguments:
+ return f"{str(self.apply_func.__name__)}(...)"
- def map(self, fn: Callable[[U], V]) -> StateMonadMixin[State, V]:
+ args = str(self.arguments)
+ if isinstance(self.arguments, Iterable):
+ args = ", ".join(map(str, self.arguments))
+
+ return f"{str(self.apply_func.__name__)}({args})"
+
+ def map(self, fn: Callable[[U], V]) -> StateMonad[State, V]:
@wraps(fn)
def internal_map(state: State) -> Tuple[State, U]:
n_state, val = self.apply(state)
@@ -64,7 +47,7 @@ class StateMonadMixin(Generic[State, U], ABC):
return replace(self, apply_func=internal_map)
- def flat_map(self, fn: Callable[[U], StateMonadMixin[State, V]]) -> StateMonadMixin[State, V]:
+ def flat_map(self, fn: Callable[[U], StateMonad[State, V]]) -> StateMonad[State, V]:
@wraps(fn)
def internal_map(state: State) -> Tuple[State, V]:
n_state, val = self.apply(state)
@@ -73,7 +56,7 @@ class StateMonadMixin(Generic[State, U], ABC):
return replace(self, apply_func=internal_map)
# FIXME: typing
- def zip(self, other: StateMonadMixin) -> StateMonadMixin:
+ def zip(self, other: StateMonad) -> StateMonad:
def internal_map(state: State) -> Tuple[State, V]:
state, val1 = self.apply(state)
state, val2 = other.apply(state)
@@ -82,7 +65,7 @@ class StateMonadMixin(Generic[State, U], ABC):
return replace(self, apply_func=internal_map)
# FIXME: typing
- def cache(self) -> StateMonadMixin:
+ def cache(self) -> StateMonad:
def internal_map(state: State) -> Tuple[State, V]:
if self in state.cache:
return state, state.cache[self]
@@ -107,3 +90,13 @@ class StateMonadMixin(Generic[State, U], ABC):
# NP: (I know what it does but the name is very vague)
def read(self, state: State) -> U:
return self.apply_func(state)[1]
+
+
+def init_state_monad(
+ apply_func: Callable,
+ arguments: Any | None = None
+):
+ return StateMonad(
+ apply_func=apply_func,
+ arguments=arguments,
+ )
diff --git a/polymatrix/statemonad/__init__.py b/polymatrix/statemonad/__init__.py
deleted file mode 100644
index 48e369e..0000000
--- a/polymatrix/statemonad/__init__.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from polymatrix.statemonad.init import init_state_monad
-from polymatrix.statemonad.abc import StateMonad
-
-# NP: this is the unit operation for the monad, why not move it inside the
-# NP: monad class?
-def from_(val):
- def func(state):
- return state, val
-
- return init_state_monad(func)
-
-
-# NP: duplicate-ish with StateMonadMixin.zip, this one is more generic
-# NP: consider moving this into the mixin and deleting this function
-def zip(monads: tuple[StateMonad]):
- def zip_func(state):
- values = tuple()
-
- for monad in monads:
- state, val = monad.apply(state)
- values += (val,)
-
- return state, values
-
- return init_state_monad(zip_func)
diff --git a/polymatrix/statemonad/abc.py b/polymatrix/statemonad/abc.py
deleted file mode 100644
index 671b5bd..0000000
--- a/polymatrix/statemonad/abc.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import abc
-
-from polymatrix.statemonad.mixins import StateMonadMixin
-
-
-class StateMonad(StateMonadMixin, abc.ABC):
- pass
diff --git a/polymatrix/statemonad/impl.py b/polymatrix/statemonad/impl.py
deleted file mode 100644
index 817ce47..0000000
--- a/polymatrix/statemonad/impl.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from typing import Callable, Any, Iterable
-import dataclassabc
-
-from polymatrix.statemonad.abc import StateMonad
-
-
-@dataclassabc.dataclassabc(frozen=True)
-class StateMonadImpl(StateMonad):
- apply_func: Callable
- arguments: Any | None = None
-
- def __str__(self):
- if not self.arguments:
- return f"{str(self.apply_func.__name__)}(...)"
-
- args = str(self.arguments)
- if isinstance(self.arguments, Iterable):
- args = ", ".join(map(str, self.arguments))
-
- return f"{str(self.apply_func.__name__)}({args})"
diff --git a/polymatrix/statemonad/init.py b/polymatrix/statemonad/init.py
deleted file mode 100644
index 0abe6e4..0000000
--- a/polymatrix/statemonad/init.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from typing import Callable, Any
-
-from polymatrix.statemonad.impl import StateMonadImpl
-
-
-def init_state_monad(
- apply_func: Callable,
- arguments: Any | None = None
-):
- return StateMonadImpl(
- apply_func=apply_func,
- arguments=arguments,
- )