From d63df1c1bbc9b963154a656ef94353aa8e5e3f93 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Thu, 9 May 2024 16:22:43 +0200 Subject: Adapt CombinationsExprMixin to work with expression coefficients --- polymatrix/expression/impl.py | 4 +- .../expression/mixins/combinationsexprmixin.py | 64 ++++++++++++++++++---- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 429a0b2..e6b021f 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -103,10 +103,10 @@ class CacheExprImpl(CacheExprMixin): @dataclassabc.dataclassabc(frozen=True) class CombinationsExprImpl(CombinationsExprMixin): expression: ExpressionBaseMixin - degrees: tuple[int, ...] + degrees: ExpressionBaseMixin | tuple[int, ...] def __str__(self): - if len(self.degrees) == 1: + if isinstance(self.degrees, tuple) and len(self.degrees) == 1: return f"combinations({self.expression}, {self.degrees[0]})" return f"combinations({self.expression}, {self.degrees})" diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py index 56a1d64..121d71c 100644 --- a/polymatrix/expression/mixins/combinationsexprmixin.py +++ b/polymatrix/expression/mixins/combinationsexprmixin.py @@ -1,5 +1,6 @@ import abc import itertools +from typing import Iterable from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.polymatrix.init import init_poly_matrix @@ -9,35 +10,76 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin class CombinationsExprMixin(ExpressionBaseMixin): + # FIXME: improve docstring """ combination using degrees=(0, 1, 2, 3): [[x]] -> [[1], [x], [x**2], [x**3]] """ - # NP: example is not great, should be with x_1 and x_2 to show actual - # NP: effect of combinations (or am I understanding this wrong?) @property @abc.abstractmethod - def expression(self) -> ExpressionBaseMixin: ... + def expression(self) -> ExpressionBaseMixin: + """ Column vector. """ @property @abc.abstractmethod - def degrees(self) -> tuple[int, ...]: ... + def degrees(self) -> ExpressionBaseMixin | tuple[int, ...]: + """ + Vector or scalar expression, or a list of integers. + """ - # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: - state, poly_matrix = self.expression.apply(state=state) - - assert poly_matrix.shape[1] == 1 + state, expr_pm = self.expression.apply(state) + + degrees: Iterable | None = None + + if isinstance(self.degrees, ExpressionBaseMixin): + state, deg_pm = self.degrees.apply(state) + + # Check that it is a constant + for entry, poly in deg_pm.entries(): + if not poly.is_constant(): + # FIXME: improve error message + raise ValueError("Non-constant exponent resulting from " + f"evaluating {self.degrees}. Exponent must be a constant!") + + # Scalars are OK + nrows, ncols = deg_pm.shape + if nrows == 1 and ncols == 1: + degrees = (deg_pm.at(0, 0).constant(),) + + # Column vectors are OK + elif nrows == 1: + degrees = (deg_pm.at(i, 0).constant() + for i in range(nrows)) + + # Row vectors are OK + elif ncols == 1: + degrees = (deg_pm.at(0, i).constant() + for i in range(ncols)) + + # Matrices are not OK + else: + raise ValueError(f"Invalid exponent with shape {deg_pm.shape} resulting from {self.degrees} " + "Exponent can only be a scalar or a vector of exponents " + "(multi-index), matrices are not allowed.") + + elif isinstance(self.degrees, tuple): + degrees = self.degrees + + # TODO: check that degrees are all integers + + # FIXME: improve error message + assert expr_pm.shape[1] == 1 def gen_indices(): - for degree in self.degrees: + for degree in degrees: yield from itertools.combinations_with_replacement( - range(poly_matrix.shape[0]), degree + range(expr_pm.shape[0]), degree ) indices = tuple(gen_indices()) @@ -51,7 +93,7 @@ class CombinationsExprMixin(ExpressionBaseMixin): continue def acc_product(left, row): - right = poly_matrix.get_poly(row, 0) + right = expr_pm.get_poly(row, 0) if len(left) == 0: return right -- cgit v1.2.1