summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/to.py2
-rw-r--r--polymatrix/statemonad/impl.py13
-rw-r--r--polymatrix/statemonad/init.py4
-rw-r--r--polymatrix/statemonad/mixins.py12
4 files changed, 28 insertions, 3 deletions
diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py
index 2707fb4..b1de626 100644
--- a/polymatrix/expression/to.py
+++ b/polymatrix/expression/to.py
@@ -76,4 +76,4 @@ def to_sympy(
return state, m[0, 0]
return state, m
- return init_state_monad(polymatrix_to_sympy)
+ return init_state_monad(polymatrix_to_sympy, expr)
diff --git a/polymatrix/statemonad/impl.py b/polymatrix/statemonad/impl.py
index 14ad503..1b78daa 100644
--- a/polymatrix/statemonad/impl.py
+++ b/polymatrix/statemonad/impl.py
@@ -1,4 +1,4 @@
-from typing import Callable
+from typing import Callable, Any, Iterable
import dataclassabc
from polymatrix.statemonad.abc import StateMonad
@@ -7,3 +7,14 @@ from polymatrix.statemonad.abc import StateMonad
@dataclassabc.dataclassabc(frozen=True)
class StateMonadImpl(StateMonad):
apply_func: Callable
+ arguments: Any | None = None
+
+ def __str__(self):
+ if not self.arguments:
+ return super().__str__(self)
+
+ args = str(self.arguments)
+ if isinstance(self.arguments, Iterable):
+ args = ", ".join(map(str, self.arguments))
+
+ return f"{str(self.apply_func.__name__)}({args})"
diff --git a/polymatrix/statemonad/init.py b/polymatrix/statemonad/init.py
index f3a8bd6..0abe6e4 100644
--- a/polymatrix/statemonad/init.py
+++ b/polymatrix/statemonad/init.py
@@ -1,11 +1,13 @@
-from typing import Callable
+from typing import Callable, Any
from polymatrix.statemonad.impl import StateMonadImpl
def init_state_monad(
apply_func: Callable,
+ arguments: Any | None = None
):
return StateMonadImpl(
apply_func=apply_func,
+ arguments=arguments,
)
diff --git a/polymatrix/statemonad/mixins.py b/polymatrix/statemonad/mixins.py
index e2bfda9..68d1db6 100644
--- a/polymatrix/statemonad/mixins.py
+++ b/polymatrix/statemonad/mixins.py
@@ -44,6 +44,18 @@ class StateMonadMixin(
# NP: TODO comment
...
+ @property
+ @abc.abstractmethod
+ def arguments(self) -> U | None:
+ # arguments that were given to the function apply_func.
+ # this field is optional
+
+ # TODO: review this. It was added because I want to be able to see what
+ # was passed to the statemonads that are applied to expressions. For
+ # example in to_sympy, you want so see what expression is converted to
+ # sympy.
+ ...
+
# NP: typing, use from __future__ import annotations
def map(self, fn: Callable[[U], V]) -> 'StateMonadMixin[State, V]':
# NP: add functools.wrap(fn) decorator to copy docstrings etc.