From 6bca5974ed1da7f334f305f0c2221ba45e933b90 Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Sat, 11 May 2024 16:52:42 +0200
Subject: Add more helper methods to ExpressionState

---
 polymatrix/expressionstate/mixins.py | 34 ++++++++++++++++++++++++++++++++--
 1 file changed, 32 insertions(+), 2 deletions(-)

diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py
index 5d26c89..8d1145f 100644
--- a/polymatrix/expressionstate/mixins.py
+++ b/polymatrix/expressionstate/mixins.py
@@ -63,8 +63,7 @@ class ExpressionStateMixin(
     def register(self, var: Variable) -> ExpressionStateMixin:
         """
         Create an index for a variable, but does not return the index. If you
-        want the index use
-        :py:meth:`index`
+        want the index range use :py:meth:`index`
         """
         state, _ = self.index(var)
         return state
@@ -111,6 +110,37 @@ class ExpressionStateMixin(
 
         raise IndexError(f"There is no variable with index {index}.")
 
+    def get_variable_from_variable_index(self, var: VariableIndex) -> Variable:
+        """ Get the variable object from the index contained in a `VariableIndex` """
+        return self.get_variable(var.index)
+
+    def get_variables_from_monomial_index(self, monomial: MonomialIndex) -> Iterable[Variable]:
+        """ Get all variable objects from the indices contained in a `MonomialIndex` """
+        for v in monomial:
+            # FIXME: non-scalar variable will be yielded multiple times
+            yield self.get_variable_from_variable_index(v)
+
+    def get_variable_from_name_or(self, name: str, if_not_present: Any) -> Variable | Any:
+        """
+        Get a variable object given its name, or if there is no variable with
+        the given name return what is passed in the `if_not_present` argument.
+        """
+        for v in self.indices.keys():
+            if v.name == name:
+                return name
+
+        return if_not_present
+
+    def get_variable_from_name(self, name: str) -> Variable:
+        """
+        Get a variable object given its name, raises KeyError if there is no
+        variable with the given name.
+        """
+        if v := self.get_variable_from_name_or(name, False):
+            return v
+
+        raise KeyError(f"There is no variable named {name}")
+
     def get_name(self, index: int) -> str:
         """ Get the name of a variable given its index. """
         for variable, (start, end) in self.indices.items():
-- 
cgit v1.2.1