diff options
Diffstat (limited to 'mdpoly/operations/exp.py')
-rw-r--r-- | mdpoly/operations/exp.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/mdpoly/operations/exp.py b/mdpoly/operations/exp.py index 0f57fcb..2b17a07 100644 --- a/mdpoly/operations/exp.py +++ b/mdpoly/operations/exp.py @@ -8,8 +8,10 @@ from dataclasses import dataclass from ..abc import Expr, Const from ..errors import AlgebraicError +from ..index import Shape from . import BinaryOp, Reducible +from .mul import PolyMul # TODO: implement matrix exponential, use caley-hamilton thm magic @@ -22,12 +24,13 @@ class MatExp(BinaryOp, Reducible): @dataclass(eq=False) class PolyExp(BinaryOp, Reducible): """ Exponentiation operator between scalar polynomials. """ + shape: Shape = Shape.scalar() @property # type: ignore[override] def right(self) -> Const: # type: ignore[override] if not isinstance(super().right, Const): - raise AlgebraicError(f"Cannot raise {self.left} to {self.right} because" - f"{self.right} is not a constant.") + raise AlgebraicError(f"Cannot raise {str(self.left)} to {str(super().right)} " + f"because {str(super().right)} is not a constant.") return cast(Const, super().right) @@ -44,7 +47,7 @@ class PolyExp(BinaryOp, Reducible): raise NotImplementedError("Cannot raise to non-integer powers (yet).") ntimes = self.right.value - 1 - return reduce(opmul, (var for _ in range(ntimes)), var) + return reduce(PolyMul, (var for _ in range(ntimes)), var) def __str__(self) -> str: return f"({self.left} ** {self.right})" |