From 736b1762d0e706c4a8b4ac7d7bd2f32cdbc88f80 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Tue, 7 May 2024 00:33:09 +0200 Subject: Improve PolyMatrixAsAffineExpr.affine_coefficient --- polymatrix/polymatrix/mixins.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py index e13b868..81aa2dd 100644 --- a/polymatrix/polymatrix/mixins.py +++ b/polymatrix/polymatrix/mixins.py @@ -350,22 +350,24 @@ class PolyMatrixAsAffineExpressionMixin( """ return self.data - def affine_coefficient(self, monomial: MonomialIndex) -> MatrixType | None: + def affine_coefficient(self, monomial: MonomialIndex) -> MatrixType: r""" Get the affine coefficient :math:`A_\alpha` associated to :math:`x^\alpha`. """ - if monomial not in self.slices: - # FIXME: do not return none, but a falsy element of MatrixType - return None + if monomial not in self.slices.keys(): + nrows, _ = self.shape + return np.zeros((nrows, 1)) columns = range(*self.slices[monomial]) return self.data[:, columns] - def affine_coefficients_by_degrees(self) -> Iterable[tuple[int, MatrixType]]: + def affine_coefficients_by_degrees(self) -> Iterable[tuple[int, tuple[MonomialIndex], MatrixType]]: """ Iterate over the coefficients grouped by degree. """ groups = itertools.groupby(self.slices.keys(), lambda m: m.degree) for degree, monomials in groups: - yield degree, np.hstack(list(self.data[:, self.slices[m]] for m in monomials)) + monomials = tuple(monomials) + columns = np.hstack(list(self.data[:, self.slices[m]] for m in monomials)) + yield degree, monomials, columns def affine_coefficients_by_variable(self) -> Iterable[tuple[int, MatrixType]]: r""" @@ -383,7 +385,6 @@ class PolyMatrixAsAffineExpressionMixin( for v, monomials in groups: yield v, np.hstack(list(self.data[:, self.slices[m]] for m in monomials)) - def affine_eval(self, x: MatrixType) -> MatrixType: r""" Evaluate the affine expression @@ -410,7 +411,8 @@ class PolyMatrixAsAffineExpressionMixin( # structure to reduce the number of computations necessary. But the # efficiency gain is useless if most of the computed terms are not # used. - if len(self.slices) > n / 2: + # TODO: tune heuristic + if len(self.slices) > (n / 2): # Compute all powers of x up to degree d all_monomials = self.monomials_eval_all(x) monomials = { -- cgit v1.2.1