aboutsummaryrefslogtreecommitdiffstats
path: root/mdpoly/operations/exp.py
diff options
context:
space:
mode:
Diffstat (limited to 'mdpoly/operations/exp.py')
-rw-r--r--mdpoly/operations/exp.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/mdpoly/operations/exp.py b/mdpoly/operations/exp.py
index 0f57fcb..2b17a07 100644
--- a/mdpoly/operations/exp.py
+++ b/mdpoly/operations/exp.py
@@ -8,8 +8,10 @@ from dataclasses import dataclass
from ..abc import Expr, Const
from ..errors import AlgebraicError
+from ..index import Shape
from . import BinaryOp, Reducible
+from .mul import PolyMul
# TODO: implement matrix exponential, use caley-hamilton thm magic
@@ -22,12 +24,13 @@ class MatExp(BinaryOp, Reducible):
@dataclass(eq=False)
class PolyExp(BinaryOp, Reducible):
""" Exponentiation operator between scalar polynomials. """
+ shape: Shape = Shape.scalar()
@property # type: ignore[override]
def right(self) -> Const: # type: ignore[override]
if not isinstance(super().right, Const):
- raise AlgebraicError(f"Cannot raise {self.left} to {self.right} because"
- f"{self.right} is not a constant.")
+ raise AlgebraicError(f"Cannot raise {str(self.left)} to {str(super().right)} "
+ f"because {str(super().right)} is not a constant.")
return cast(Const, super().right)
@@ -44,7 +47,7 @@ class PolyExp(BinaryOp, Reducible):
raise NotImplementedError("Cannot raise to non-integer powers (yet).")
ntimes = self.right.value - 1
- return reduce(opmul, (var for _ in range(ntimes)), var)
+ return reduce(PolyMul, (var for _ in range(ntimes)), var)
def __str__(self) -> str:
return f"({self.left} ** {self.right})"