summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/polymatrix/mixins.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py
index 4dcd6f7..95c875a 100644
--- a/polymatrix/polymatrix/mixins.py
+++ b/polymatrix/polymatrix/mixins.py
@@ -348,26 +348,38 @@ class PolyMatrixAsAffineExpressionMixin(
# Sort the values by monomial index, which have a total order
return dict(sorted(monomial_values.items()))
-
def affine_coefficient(self, monomial: MonomialIndex) -> MatrixType:
r""" Get the affine coefficient :math:`A_\alpha` associated to :math:`x^\alpha`. """
if monomial not in self.slices.keys():
nrows, _ = self.shape
- return np.zeros((nrows, 1))
+ return np.zeros(self.shape)
columns = self.slices[monomial]
return self.affine_coefficients[:, columns]
- def affine_coefficients_by_degrees(self) -> Iterable[tuple[int, MatrixType]]:
+ def affine_coefficients_by_degrees(
+ self,
+ variables: Iterable[VariableIndex] | Iterable[MonomialIndex] | None = None
+ ) -> Iterable[tuple[int, MatrixType]]:
"""
Iterate over the coefficients grouped by degree.
"""
- # Get involved variables
- # TODO: this value should be cached
- variables = sorted(set(VariableIndex(v.index, power=1)
- for m in self.slices.keys()
- for v in m))
-
+ if not variables:
+ # Get involved variables
+ # TODO: this value should be cached
+ variables = tuple(set(VariableIndex(v.index, power=1)
+ for m in self.slices.keys()
+ for v in m))
+
+ # FIXME: allow passing linear monomial indices
+ # elif isinstance(variables[0], MonomialIndex):
+ # for v in variables:
+ # if v.degree != 1:
+ # raise ValueError("Variables ...")
+ # variables = map(lambda m: m[0], variables)
+
+
+ variables = sorted(variables)
for degree in range(self.degree +1):
monomials = MonomialIndex.combinations_of_degree(variables, degree)
columns = tuple(self.affine_coefficient(m) for m in monomials)