summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-11 16:52:42 +0200
committerNao Pross <np@0hm.ch>2024-05-11 17:36:47 +0200
commit6bca5974ed1da7f334f305f0c2221ba45e933b90 (patch)
tree0102ab3c7c48cce060b3c16b96ed0fb42a122268
parentClean up PolyMatrixAsAffineExpression (diff)
downloadpolymatrix-6bca5974ed1da7f334f305f0c2221ba45e933b90.tar.gz
polymatrix-6bca5974ed1da7f334f305f0c2221ba45e933b90.zip
Add more helper methods to ExpressionState
-rw-r--r--polymatrix/expressionstate/mixins.py34
1 files changed, 32 insertions, 2 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py
index 5d26c89..8d1145f 100644
--- a/polymatrix/expressionstate/mixins.py
+++ b/polymatrix/expressionstate/mixins.py
@@ -63,8 +63,7 @@ class ExpressionStateMixin(
def register(self, var: Variable) -> ExpressionStateMixin:
"""
Create an index for a variable, but does not return the index. If you
- want the index use
- :py:meth:`index`
+ want the index range use :py:meth:`index`
"""
state, _ = self.index(var)
return state
@@ -111,6 +110,37 @@ class ExpressionStateMixin(
raise IndexError(f"There is no variable with index {index}.")
+ def get_variable_from_variable_index(self, var: VariableIndex) -> Variable:
+ """ Get the variable object from the index contained in a `VariableIndex` """
+ return self.get_variable(var.index)
+
+ def get_variables_from_monomial_index(self, monomial: MonomialIndex) -> Iterable[Variable]:
+ """ Get all variable objects from the indices contained in a `MonomialIndex` """
+ for v in monomial:
+ # FIXME: non-scalar variable will be yielded multiple times
+ yield self.get_variable_from_variable_index(v)
+
+ def get_variable_from_name_or(self, name: str, if_not_present: Any) -> Variable | Any:
+ """
+ Get a variable object given its name, or if there is no variable with
+ the given name return what is passed in the `if_not_present` argument.
+ """
+ for v in self.indices.keys():
+ if v.name == name:
+ return name
+
+ return if_not_present
+
+ def get_variable_from_name(self, name: str) -> Variable:
+ """
+ Get a variable object given its name, raises KeyError if there is no
+ variable with the given name.
+ """
+ if v := self.get_variable_from_name_or(name, False):
+ return v
+
+ raise KeyError(f"There is no variable named {name}")
+
def get_name(self, index: int) -> str:
""" Get the name of a variable given its index. """
for variable, (start, end) in self.indices.items():