summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/polymatrix/mixins.py26
1 files changed, 25 insertions, 1 deletions
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py
index bb2fc17..b35490d 100644
--- a/polymatrix/polymatrix/mixins.py
+++ b/polymatrix/polymatrix/mixins.py
@@ -63,6 +63,20 @@ class PolyMatrixMixin(ABC):
if poly := self.at(row, col):
yield MatrixIndex(row, col), poly
+ def variables(self) -> set[VariableIndex]:
+ """
+ Get the indices of all variable indices that are in the polymatrix
+ object. Only scalar VariableIndices are returned, i.e. they all have a
+ power of one.
+ """
+ variables = set()
+ for _, poly in self.entries():
+ for monomial, _ in poly.terms():
+ for var in monomial:
+ variables.add(VariableIndex(var.index, power=1))
+
+ return variables
+
# -- Old API ---
@deprecated("Replaced by PolyMatrixMixin.entries()")
@@ -124,7 +138,17 @@ class BroadcastPolyMatrixMixin(PolyMatrixMixin, ABC):
@override
def at(self, row: int, col: int) -> PolyDict:
""" See :py:meth:`PolyMatrixMixin.at`. """
- return self.data
+ return self.polynomial
+
+ @override
+ def variables(self) -> set[VariableIndex]:
+ # Override to be more efficient
+ variables = set()
+ for monomial, _ in self.poly.terms():
+ for var in monomial:
+ variables.add(VariableIndex(var.index, power=1))
+
+ return variables
# --- Old API ---