From 573b7445866e1f870ebc1c0482114f1c19e2e7e0 Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Thu, 9 May 2024 12:43:29 +0200
Subject: Create FromStateMonadMixin

---
 polymatrix/__init__.py                         | 32 +++++++++++++++------
 polymatrix/expression/from_.py                 |  4 +++
 polymatrix/expression/impl.py                  | 20 +++++++++++--
 polymatrix/expression/mixins/fromstatemonad.py | 39 ++++++++++++++++++++++++++
 4 files changed, 85 insertions(+), 10 deletions(-)
 create mode 100644 polymatrix/expression/mixins/fromstatemonad.py

diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index f68f364..94da22a 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -2,13 +2,26 @@ from polymatrix.expressionstate.abc import ExpressionState as internal_Expressio
 from polymatrix.expressionstate.init import (
     init_expression_state as internal_init_expression_state,
 )
-from polymatrix.expression.expression import Expression as internal_Expression
-from polymatrix.expression.from_ import from_ as internal_from, from_names as internal_from_names, from_name as internal_from_name
-from polymatrix.expression import v_stack as internal_v_stack
-from polymatrix.expression import h_stack as internal_h_stack
-from polymatrix.expression import product as internal_product
-from polymatrix.expression.to import to_constant as internal_to_constant
-from polymatrix.expression.to import to_sympy as internal_to_sympy
+
+from polymatrix.expression.from_ import (
+    from_ as internal_from,
+    from_names as internal_from_names,
+    from_name as internal_from_name,
+    from_statemonad as internal_from_statemonad,
+)
+
+from polymatrix.expression import (
+    Expression as internal_Expression,
+    v_stack as internal_v_stack,
+    h_stack as internal_h_stack,
+    product as internal_product,
+)
+
+from polymatrix.expression.to import (
+    to_constant as internal_to_constant,
+    to_sympy as internal_to_sympy
+)
+
 from polymatrix.denserepr.from_ import from_polymatrix
 from polymatrix.polymatrix.init import to_affine_expression
 
@@ -18,10 +31,10 @@ ExpressionState = internal_ExpressionState
 init_expression_state = internal_init_expression_state
 make_state = init_expression_state
 
-from_ = internal_from
 v_stack = internal_v_stack
 h_stack = internal_h_stack
 product = internal_product
+
 to_constant_repr = internal_to_constant
 to_constant = internal_to_constant
 to_sympy_repr = internal_to_sympy
@@ -29,5 +42,8 @@ to_sympy = internal_to_sympy
 to_matrix_repr = from_polymatrix
 to_dense = from_polymatrix
 to_affine = to_affine_expression
+
+from_ = internal_from
 from_names = internal_from_names
 from_name = internal_from_name
+from_statemonad = internal_from_statemonad
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index 1e50c40..52f9d06 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -33,6 +33,10 @@ def from_(
     )
 
 
+def from_statemonad(monad: StateMonad):
+    return init_expression(polymatrix.expression.init.init_from_statemonad(monad))
+
+
 def from_names(names: str, shape: tuple[int, int] = (1,1)) -> Iterable[VariableExpression]:
     """ Construct one or multiple variables from comma separated a list of names. """
     for name in names.split(","):
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 0ddda2e..e638aca 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -1,10 +1,11 @@
 import numpy.typing
 import sympy
 from typing_extensions import override
+from polymatrix.statemonad.abc import StateMonad
 
 import dataclassabc
-from polymatrix.expression.mixins.integrateexprmixin import IntegrateExprMixin
 
+from polymatrix.expression.mixins.integrateexprmixin import IntegrateExprMixin
 from polymatrix.expression.mixins.legendreseriesmixin import LegendreSeriesMixin
 from polymatrix.expression.mixins.productexprmixin import ProductExprMixin
 from polymatrix.utils.getstacklines import FrameSummary
@@ -25,6 +26,7 @@ from polymatrix.expression.mixins.fromsymmetricmatrixexprmixin import (
 )
 from polymatrix.expression.mixins.fromnumbersexprmixin import FromNumbersExprMixin
 from polymatrix.expression.mixins.fromnumpyexprmixin import FromNumpyExprMixin
+from polymatrix.expression.mixins.fromstatemonad import FromStateMonadMixin
 from polymatrix.expression.mixins.fromsympyexprmixin import FromSympyExprMixin
 from polymatrix.expression.mixins.fromtermsexprmixin import (
     FromPolynomialDataExprMixin,
@@ -174,13 +176,24 @@ class FromNumbersExprImpl(FromNumbersExprMixin):
 class FromNumpyExprImpl(FromNumpyExprMixin):
     data: numpy.typing.NDArray
 
+    def __str__(self):
+        return f"from_numpy({self.data})"
+
+
+@dataclassabc.dataclassabc(frozen=True)
+class FromStateMonadImpl(FromStateMonadMixin):
+    monad: StateMonad
+
+    def __str__(self):
+        return f"from_statemonad({self.monad})"
+
 
 @dataclassabc.dataclassabc(frozen=True)
 class FromSympyExprImpl(FromSympyExprMixin):
     data: sympy.Expr | sympy.Matrix | tuple[tuple[sympy.Expr]]
 
     def __str__(self):
-        return f"FromSympy({self.data})"
+        return f"from_sympy({self.data})"
 
 
 @dataclassabc.dataclassabc(frozen=True)
@@ -275,6 +288,9 @@ class DegreeExprImpl(DegreeExprMixin):
     def __repr__(self):
         return f"{self.__class__.__name__}(underlying={self.underlying})"
 
+    def __str__(self):
+        return f"degree({str(self.underlying)})"
+
 
 @dataclassabc.dataclassabc(frozen=True)
 class MaxExprImpl(MaxExprMixin):
diff --git a/polymatrix/expression/mixins/fromstatemonad.py b/polymatrix/expression/mixins/fromstatemonad.py
new file mode 100644
index 0000000..216b160
--- /dev/null
+++ b/polymatrix/expression/mixins/fromstatemonad.py
@@ -0,0 +1,39 @@
+from abc import abstractmethod
+from typing_extensions import override
+
+from polymatrix.expression.expression import Expression
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate.mixins import ExpressionStateMixin
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.statemonad.abc import StateMonad
+
+class FromStateMonadMixin(ExpressionBaseMixin):
+    """
+    Make a expression from a `StateMonad` object wrapping a function
+    that returns one of the following types:
+
+      - Expression
+      - ExpressionBaseMixin
+      - PolyMatrix
+    """
+
+    @property
+    @abstractmethod
+    def monad(self) -> StateMonad:
+        """ The state monad object. """
+
+    @override
+    def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+        state, expr = self.monad.apply(state)
+        if isinstance(expr, Expression):
+            return expr.underlying.apply(state)
+
+        elif isinstance(expr, ExpressionBaseMixin):
+            return expr.apply(state)
+
+        elif isinstance(expr, PolyMatrixMixin):
+            return state, expr
+
+        else:
+            raise TypeError(f"Return type of StateMonad object {self.monad} "
+                            "must be of type Expression or PolyMatrix!")
-- 
cgit v1.2.1