diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/polymatrix/mixins.py | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py index bb2fc17..b35490d 100644 --- a/polymatrix/polymatrix/mixins.py +++ b/polymatrix/polymatrix/mixins.py @@ -63,6 +63,20 @@ class PolyMatrixMixin(ABC): if poly := self.at(row, col): yield MatrixIndex(row, col), poly + def variables(self) -> set[VariableIndex]: + """ + Get the indices of all variable indices that are in the polymatrix + object. Only scalar VariableIndices are returned, i.e. they all have a + power of one. + """ + variables = set() + for _, poly in self.entries(): + for monomial, _ in poly.terms(): + for var in monomial: + variables.add(VariableIndex(var.index, power=1)) + + return variables + # -- Old API --- @deprecated("Replaced by PolyMatrixMixin.entries()") @@ -124,7 +138,17 @@ class BroadcastPolyMatrixMixin(PolyMatrixMixin, ABC): @override def at(self, row: int, col: int) -> PolyDict: """ See :py:meth:`PolyMatrixMixin.at`. """ - return self.data + return self.polynomial + + @override + def variables(self) -> set[VariableIndex]: + # Override to be more efficient + variables = set() + for monomial, _ in self.poly.terms(): + for var in monomial: + variables.add(VariableIndex(var.index, power=1)) + + return variables # --- Old API --- |