diff options
author | Nao Pross <np@0hm.ch> | 2024-06-02 15:40:14 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-06-02 15:40:14 +0200 |
commit | bf30f234c95fa33265b2ad0334f7e0932815bb11 (patch) | |
tree | c5210156c3d0d3586a1d9feec4523d25d8a86649 | |
parent | Create helper PolyDict.terms_of_degree(d) (diff) | |
download | polymatrix-bf30f234c95fa33265b2ad0334f7e0932815bb11.tar.gz polymatrix-bf30f234c95fa33265b2ad0334f7e0932815bb11.zip |
Introduce PolyMatrix.variables() to retrieve all variables
Diffstat (limited to '')
-rw-r--r-- | polymatrix/polymatrix/mixins.py | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py index bb2fc17..b35490d 100644 --- a/polymatrix/polymatrix/mixins.py +++ b/polymatrix/polymatrix/mixins.py @@ -63,6 +63,20 @@ class PolyMatrixMixin(ABC): if poly := self.at(row, col): yield MatrixIndex(row, col), poly + def variables(self) -> set[VariableIndex]: + """ + Get the indices of all variable indices that are in the polymatrix + object. Only scalar VariableIndices are returned, i.e. they all have a + power of one. + """ + variables = set() + for _, poly in self.entries(): + for monomial, _ in poly.terms(): + for var in monomial: + variables.add(VariableIndex(var.index, power=1)) + + return variables + # -- Old API --- @deprecated("Replaced by PolyMatrixMixin.entries()") @@ -124,7 +138,17 @@ class BroadcastPolyMatrixMixin(PolyMatrixMixin, ABC): @override def at(self, row: int, col: int) -> PolyDict: """ See :py:meth:`PolyMatrixMixin.at`. """ - return self.data + return self.polynomial + + @override + def variables(self) -> set[VariableIndex]: + # Override to be more efficient + variables = set() + for monomial, _ in self.poly.terms(): + for var in monomial: + variables.add(VariableIndex(var.index, power=1)) + + return variables # --- Old API --- |