summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/polymatrix/mixins.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py
index cbd9749..76c31c9 100644
--- a/polymatrix/polymatrix/mixins.py
+++ b/polymatrix/polymatrix/mixins.py
@@ -204,7 +204,10 @@ class PolyMatrixAsAffineExpressionMixin(
@property
@abc.abstractmethod
def slices(self) -> dict[MonomialIndex, tuple[int, int]]:
- # column slices of the big array stored in self.data to get the individual A_alphas
+ r"""
+ Map from monomial indices to column slices of the big matrix that
+ stores all :math:`A_\alpha`.
+ """
...
@override
@@ -324,8 +327,19 @@ class PolyMatrixAsAffineExpressionMixin(
# Sort the values by monomial index, which have a total order
return dict(sorted(monomial_values.items()))
- def affine_coefficient(self, monomial: MonomialIndex) -> MatrixType:
+ def affine_coefficients(self) -> MatrixType:
+ r"""
+ Get the large matrix
+ :math:`A = \begin{bmatrix} A_{\alpha_1} & A_{\alpha_2} & \cdots & A_{\alpha_N} \end{bmatrix}`
+ """
+ return self.data
+
+ def affine_coefficient(self, monomial: MonomialIndex) -> MatrixType | None:
r""" Get the affine coefficient :math:`A_\alpha` associated to :math:`x^\alpha`. """
+ if monomial not in self.slices:
+ # FIXME: do not return none, but a falsy element of MatrixType
+ return None
+
columns = range(*self.slices[monomial])
return self.data[:, columns]
@@ -348,7 +362,7 @@ class PolyMatrixAsAffineExpressionMixin(
monomials = np.array(tuple(self.monomials_eval(tuple(x)).items()))
# Evaluate the affine expression
- nrows, ncols = self.shape
+ _, ncols = self.shape
return self.data @ (np.kron(monomials, np.eye(ncols)))
def affine_eval_fn(self) -> Callable[[MatrixType], MatrixType]:
@@ -357,6 +371,5 @@ class PolyMatrixAsAffineExpressionMixin(
:math:`p(x) = \sum_{\alpha \in \mathcal{E}\langle x \rangle_d} A_\alpha x^\alpha`
at :math:`x`.
"""
- # TODO: docstring
# TODO: If slow consider replacing with toolz.functoolz.curry from ctoolz
return functools.partial(PolyMatrixAsAffineExpression.affine_eval, self)