diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expressionstate/mixins.py | 34 |
1 files changed, 32 insertions, 2 deletions
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py index 5d26c89..8d1145f 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate/mixins.py @@ -63,8 +63,7 @@ class ExpressionStateMixin( def register(self, var: Variable) -> ExpressionStateMixin: """ Create an index for a variable, but does not return the index. If you - want the index use - :py:meth:`index` + want the index range use :py:meth:`index` """ state, _ = self.index(var) return state @@ -111,6 +110,37 @@ class ExpressionStateMixin( raise IndexError(f"There is no variable with index {index}.") + def get_variable_from_variable_index(self, var: VariableIndex) -> Variable: + """ Get the variable object from the index contained in a `VariableIndex` """ + return self.get_variable(var.index) + + def get_variables_from_monomial_index(self, monomial: MonomialIndex) -> Iterable[Variable]: + """ Get all variable objects from the indices contained in a `MonomialIndex` """ + for v in monomial: + # FIXME: non-scalar variable will be yielded multiple times + yield self.get_variable_from_variable_index(v) + + def get_variable_from_name_or(self, name: str, if_not_present: Any) -> Variable | Any: + """ + Get a variable object given its name, or if there is no variable with + the given name return what is passed in the `if_not_present` argument. + """ + for v in self.indices.keys(): + if v.name == name: + return name + + return if_not_present + + def get_variable_from_name(self, name: str) -> Variable: + """ + Get a variable object given its name, raises KeyError if there is no + variable with the given name. + """ + if v := self.get_variable_from_name_or(name, False): + return v + + raise KeyError(f"There is no variable named {name}") + def get_name(self, index: int) -> str: """ Get the name of a variable given its index. """ for variable, (start, end) in self.indices.items(): |