diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/impl.py | 10 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 8 | ||||
-rw-r--r-- | polymatrix/expression/mixins/powerexprmixin.py | 67 |
3 files changed, 85 insertions, 0 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index e638aca..429a0b2 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -48,6 +48,7 @@ from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMix from polymatrix.expression.mixins.parametrizematrixexprmixin import ( ParametrizeMatrixExprMixin, ) +from polymatrix.expression.mixins.powerexprmixin import PowerExprMixin from polymatrix.expression.mixins.quadraticinexprmixin import QuadraticInExprMixin from polymatrix.expression.mixins.quadraticmonomialsexprmixin import ( QuadraticMonomialsExprMixin, @@ -315,6 +316,15 @@ class ParametrizeMatrixExprImpl(ParametrizeMatrixExprMixin): @dataclassabc.dataclassabc(frozen=True) +class PowerExprImpl(PowerExprMixin): + left: ExpressionBaseMixin + right: ExpressionBaseMixin | int | float + + def __str__(self): + return f"({self.left} ** {self.right})" + + +@dataclassabc.dataclassabc(frozen=True) class ProductExprImpl(ProductExprMixin): underlying: tuple[ExpressionBaseMixin] degrees: tuple[int, ...] | None diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 82d2f16..1e5baab 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -335,6 +335,14 @@ def init_parametrize_matrix_expr( ) +def init_power_expr( + left: ExpressionBaseMixin, + right: ExpressionBaseMixin, + stack: tuple[FrameSummary] +): + return polymatrix.expression.impl.PowerExprImpl(left=left, right=right) + + def init_quadratic_in_expr( underlying: ExpressionBaseMixin, monomials: ExpressionBaseMixin, diff --git a/polymatrix/expression/mixins/powerexprmixin.py b/polymatrix/expression/mixins/powerexprmixin.py new file mode 100644 index 0000000..ffc9d07 --- /dev/null +++ b/polymatrix/expression/mixins/powerexprmixin.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import math + +from abc import abstractmethod +from typing_extensions import override + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin +from polymatrix.expressionstate.mixins import ExpressionStateMixin +from polymatrix.polymatrix.mixins import PolyMatrixMixin + + +class PowerExprMixin(ExpressionBaseMixin): + """ + Raise an expression to an integral power. If the expression is not scalar, + it is interpreted as elementwise exponentiation. + + The exponent may be an expression, but the expression must evaluate to an + integer constant. + """ + + @property + @abstractmethod + def left(self) -> ExpressionBaseMixin: ... + + @property + @abstractmethod + def right(self) -> ExpressionBaseMixin | int | float: ... + + @override + def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + exponent: int | None = None + + # Right (exponent) must end up being a scalar constant + if isinstance(self.right, int): + exponent = self.right + + elif isinstance(self.right, float): + exponent = int(self.right) + if not math.isclose(self.right - exponent, 0): + raise ValueError("Cannot raise a variable to a non-integral power. " + f"Exponent {self.right} (float) is not close enough to an integer.") + + elif isinstance(self.right, ExpressionBaseMixin): + state, right_polymatrix = self.right.apply(state) + right = right_polymatrix.at(0, 0).constant() + + if not isinstance(right, int): + exponent = int(right) + if not math.isclose(right - exponent, 0): + raise ValueError("Cannot raise a variable to a non-integral power. " + f"Exponent {right}, resulting from {self.right} is not an integer.") + else: + exponent = right + + else: + raise TypeError(f"Cannot raise {self.left} to {self.right}, ", + f"because exponet has type {type(self.right)}") + + state, left = self.left.apply(state) + result = left + for _ in range(exponent -1): + state, result = ElemMultExprMixin.elem_mult(state, result, left) + + return state, result + |