From b8ce1f116a9b2ddaada2b33ce4269e4dc5159099 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Wed, 8 May 2024 01:12:05 +0200 Subject: Fix bug in PolyMatrixAsAffineExpr.affine_coefficient --- polymatrix/polymatrix/mixins.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py index 4dcd6f7..95c875a 100644 --- a/polymatrix/polymatrix/mixins.py +++ b/polymatrix/polymatrix/mixins.py @@ -348,26 +348,38 @@ class PolyMatrixAsAffineExpressionMixin( # Sort the values by monomial index, which have a total order return dict(sorted(monomial_values.items())) - def affine_coefficient(self, monomial: MonomialIndex) -> MatrixType: r""" Get the affine coefficient :math:`A_\alpha` associated to :math:`x^\alpha`. """ if monomial not in self.slices.keys(): nrows, _ = self.shape - return np.zeros((nrows, 1)) + return np.zeros(self.shape) columns = self.slices[monomial] return self.affine_coefficients[:, columns] - def affine_coefficients_by_degrees(self) -> Iterable[tuple[int, MatrixType]]: + def affine_coefficients_by_degrees( + self, + variables: Iterable[VariableIndex] | Iterable[MonomialIndex] | None = None + ) -> Iterable[tuple[int, MatrixType]]: """ Iterate over the coefficients grouped by degree. """ - # Get involved variables - # TODO: this value should be cached - variables = sorted(set(VariableIndex(v.index, power=1) - for m in self.slices.keys() - for v in m)) - + if not variables: + # Get involved variables + # TODO: this value should be cached + variables = tuple(set(VariableIndex(v.index, power=1) + for m in self.slices.keys() + for v in m)) + + # FIXME: allow passing linear monomial indices + # elif isinstance(variables[0], MonomialIndex): + # for v in variables: + # if v.degree != 1: + # raise ValueError("Variables ...") + # variables = map(lambda m: m[0], variables) + + + variables = sorted(variables) for degree in range(self.degree +1): monomials = MonomialIndex.combinations_of_degree(variables, degree) columns = tuple(self.affine_coefficient(m) for m in monomials) -- cgit v1.2.1