summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-09 16:22:43 +0200
committerNao Pross <np@0hm.ch>2024-05-09 16:26:23 +0200
commitd63df1c1bbc9b963154a656ef94353aa8e5e3f93 (patch)
treec3c5635398c687d137633931ed4ca3338e88a795
parentCreate helper PolyDict.is_constant to check if a polynomial is a constant (diff)
downloadpolymatrix-d63df1c1bbc9b963154a656ef94353aa8e5e3f93.tar.gz
polymatrix-d63df1c1bbc9b963154a656ef94353aa8e5e3f93.zip
Adapt CombinationsExprMixin to work with expression coefficients
-rw-r--r--polymatrix/expression/impl.py4
-rw-r--r--polymatrix/expression/mixins/combinationsexprmixin.py64
2 files changed, 55 insertions, 13 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 429a0b2..e6b021f 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -103,10 +103,10 @@ class CacheExprImpl(CacheExprMixin):
@dataclassabc.dataclassabc(frozen=True)
class CombinationsExprImpl(CombinationsExprMixin):
expression: ExpressionBaseMixin
- degrees: tuple[int, ...]
+ degrees: ExpressionBaseMixin | tuple[int, ...]
def __str__(self):
- if len(self.degrees) == 1:
+ if isinstance(self.degrees, tuple) and len(self.degrees) == 1:
return f"combinations({self.expression}, {self.degrees[0]})"
return f"combinations({self.expression}, {self.degrees})"
diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py
index 56a1d64..121d71c 100644
--- a/polymatrix/expression/mixins/combinationsexprmixin.py
+++ b/polymatrix/expression/mixins/combinationsexprmixin.py
@@ -1,5 +1,6 @@
import abc
import itertools
+from typing import Iterable
from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
from polymatrix.polymatrix.init import init_poly_matrix
@@ -9,35 +10,76 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
class CombinationsExprMixin(ExpressionBaseMixin):
+ # FIXME: improve docstring
"""
combination using degrees=(0, 1, 2, 3):
[[x]] -> [[1], [x], [x**2], [x**3]]
"""
- # NP: example is not great, should be with x_1 and x_2 to show actual
- # NP: effect of combinations (or am I understanding this wrong?)
@property
@abc.abstractmethod
- def expression(self) -> ExpressionBaseMixin: ...
+ def expression(self) -> ExpressionBaseMixin:
+ """ Column vector. """
@property
@abc.abstractmethod
- def degrees(self) -> tuple[int, ...]: ...
+ def degrees(self) -> ExpressionBaseMixin | tuple[int, ...]:
+ """
+ Vector or scalar expression, or a list of integers.
+ """
- # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
- state, poly_matrix = self.expression.apply(state=state)
-
- assert poly_matrix.shape[1] == 1
+ state, expr_pm = self.expression.apply(state)
+
+ degrees: Iterable | None = None
+
+ if isinstance(self.degrees, ExpressionBaseMixin):
+ state, deg_pm = self.degrees.apply(state)
+
+ # Check that it is a constant
+ for entry, poly in deg_pm.entries():
+ if not poly.is_constant():
+ # FIXME: improve error message
+ raise ValueError("Non-constant exponent resulting from "
+ f"evaluating {self.degrees}. Exponent must be a constant!")
+
+ # Scalars are OK
+ nrows, ncols = deg_pm.shape
+ if nrows == 1 and ncols == 1:
+ degrees = (deg_pm.at(0, 0).constant(),)
+
+ # Column vectors are OK
+ elif nrows == 1:
+ degrees = (deg_pm.at(i, 0).constant()
+ for i in range(nrows))
+
+ # Row vectors are OK
+ elif ncols == 1:
+ degrees = (deg_pm.at(0, i).constant()
+ for i in range(ncols))
+
+ # Matrices are not OK
+ else:
+ raise ValueError(f"Invalid exponent with shape {deg_pm.shape} resulting from {self.degrees} "
+ "Exponent can only be a scalar or a vector of exponents "
+ "(multi-index), matrices are not allowed.")
+
+ elif isinstance(self.degrees, tuple):
+ degrees = self.degrees
+
+ # TODO: check that degrees are all integers
+
+ # FIXME: improve error message
+ assert expr_pm.shape[1] == 1
def gen_indices():
- for degree in self.degrees:
+ for degree in degrees:
yield from itertools.combinations_with_replacement(
- range(poly_matrix.shape[0]), degree
+ range(expr_pm.shape[0]), degree
)
indices = tuple(gen_indices())
@@ -51,7 +93,7 @@ class CombinationsExprMixin(ExpressionBaseMixin):
continue
def acc_product(left, row):
- right = poly_matrix.get_poly(row, 0)
+ right = expr_pm.get_poly(row, 0)
if len(left) == 0:
return right