summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/impl.py10
-rw-r--r--polymatrix/expression/init.py8
-rw-r--r--polymatrix/expression/mixins/powerexprmixin.py67
3 files changed, 85 insertions, 0 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index e638aca..429a0b2 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -48,6 +48,7 @@ from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMix
from polymatrix.expression.mixins.parametrizematrixexprmixin import (
ParametrizeMatrixExprMixin,
)
+from polymatrix.expression.mixins.powerexprmixin import PowerExprMixin
from polymatrix.expression.mixins.quadraticinexprmixin import QuadraticInExprMixin
from polymatrix.expression.mixins.quadraticmonomialsexprmixin import (
QuadraticMonomialsExprMixin,
@@ -315,6 +316,15 @@ class ParametrizeMatrixExprImpl(ParametrizeMatrixExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class PowerExprImpl(PowerExprMixin):
+ left: ExpressionBaseMixin
+ right: ExpressionBaseMixin | int | float
+
+ def __str__(self):
+ return f"({self.left} ** {self.right})"
+
+
+@dataclassabc.dataclassabc(frozen=True)
class ProductExprImpl(ProductExprMixin):
underlying: tuple[ExpressionBaseMixin]
degrees: tuple[int, ...] | None
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 82d2f16..1e5baab 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -335,6 +335,14 @@ def init_parametrize_matrix_expr(
)
+def init_power_expr(
+ left: ExpressionBaseMixin,
+ right: ExpressionBaseMixin,
+ stack: tuple[FrameSummary]
+):
+ return polymatrix.expression.impl.PowerExprImpl(left=left, right=right)
+
+
def init_quadratic_in_expr(
underlying: ExpressionBaseMixin,
monomials: ExpressionBaseMixin,
diff --git a/polymatrix/expression/mixins/powerexprmixin.py b/polymatrix/expression/mixins/powerexprmixin.py
new file mode 100644
index 0000000..ffc9d07
--- /dev/null
+++ b/polymatrix/expression/mixins/powerexprmixin.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+import math
+
+from abc import abstractmethod
+from typing_extensions import override
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin
+from polymatrix.expressionstate.mixins import ExpressionStateMixin
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
+
+
+class PowerExprMixin(ExpressionBaseMixin):
+ """
+ Raise an expression to an integral power. If the expression is not scalar,
+ it is interpreted as elementwise exponentiation.
+
+ The exponent may be an expression, but the expression must evaluate to an
+ integer constant.
+ """
+
+ @property
+ @abstractmethod
+ def left(self) -> ExpressionBaseMixin: ...
+
+ @property
+ @abstractmethod
+ def right(self) -> ExpressionBaseMixin | int | float: ...
+
+ @override
+ def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+ exponent: int | None = None
+
+ # Right (exponent) must end up being a scalar constant
+ if isinstance(self.right, int):
+ exponent = self.right
+
+ elif isinstance(self.right, float):
+ exponent = int(self.right)
+ if not math.isclose(self.right - exponent, 0):
+ raise ValueError("Cannot raise a variable to a non-integral power. "
+ f"Exponent {self.right} (float) is not close enough to an integer.")
+
+ elif isinstance(self.right, ExpressionBaseMixin):
+ state, right_polymatrix = self.right.apply(state)
+ right = right_polymatrix.at(0, 0).constant()
+
+ if not isinstance(right, int):
+ exponent = int(right)
+ if not math.isclose(right - exponent, 0):
+ raise ValueError("Cannot raise a variable to a non-integral power. "
+ f"Exponent {right}, resulting from {self.right} is not an integer.")
+ else:
+ exponent = right
+
+ else:
+ raise TypeError(f"Cannot raise {self.left} to {self.right}, ",
+ f"because exponet has type {type(self.right)}")
+
+ state, left = self.left.apply(state)
+ result = left
+ for _ in range(exponent -1):
+ state, result = ElemMultExprMixin.elem_mult(state, result, left)
+
+ return state, result
+