diff options
-rw-r--r-- | mdpoly/__init__.py | 4 | ||||
-rw-r--r-- | mdpoly/expressions.py | 98 | ||||
-rw-r--r-- | mdpoly/operations/exp.py | 9 |
3 files changed, 84 insertions, 27 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py index 26025ab..af44d8a 100644 --- a/mdpoly/__init__.py +++ b/mdpoly/__init__.py @@ -114,6 +114,7 @@ from .expressions import (WithOps as _WithOps, from .state import State as _State from typing import Self, Iterable +from dataclasses import dataclass # ┏━╸╻ ╻┏━┓┏━┓┏━┓╺┳╸ ┏━┓┏━┓ ╻┏━┓ @@ -139,18 +140,21 @@ class FromHelpers: yield from map(cls, names) +@dataclass class Constant(_WithOps, FromHelpers): """ Constant values """ def __init__(self, *args, **kwargs): _WithOps.__init__(self, expr=_PolyConst(*args, **kwargs)) +@dataclass class Variable(_WithOps, FromHelpers): """ Polynomial Variable """ def __init__(self, *args, **kwargs): _WithOps.__init__(self, expr=_PolyVar(*args, **kwargs)) +@dataclass class Parameter(_WithOps, FromHelpers): """ Parameter that can be substituted """ def __init__(self, *args, **kwargs): diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py index c34a0aa..57f015b 100644 --- a/mdpoly/expressions.py +++ b/mdpoly/expressions.py @@ -7,12 +7,15 @@ from functools import wraps from typing import Type, TypeVar, Iterable, Callable, Sequence, Any, Self, cast from .abc import Expr, Var, Const, Param -from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex -from .errors import MissingParameters +from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number +from .errors import MissingParameters, AlgebraicError + +from .operations.add import MatAdd, MatSub +from .operations.mul import MatElemMul +from .operations.exp import PolyExp if TYPE_CHECKING: from .abc import ReprT - from .index import Shape, Number from .state import State @@ -151,6 +154,17 @@ class WithOps: overloading and operations to expression objects. """ expr: Expr + # -- Magic methods --- + + def __enter__(self) -> Expr: + return self.expr + + def __exit__(self, *ex): + return False # Propagate exceptions + + def __str__(self): + return f"WithOps{str(self.expr)}" + # -- Monadic operations -- @staticmethod @@ -179,8 +193,7 @@ class WithOps: @staticmethod def bind(fn: Callable[[Expr], WithOps]) -> Callable[[WithOps], WithOps]: - """ Bind in the functional programming sense. Also sometimes known as - flatmap. + """ Bind in the functional programming sense. Also known as flatmap. Converts a funciton like `(Expr -> WithOps)` into `(WithOps -> WithOps)` so that it can be composed with other functions. @@ -190,37 +203,74 @@ class WithOps: return fn(arg.expr) return wrapper + @staticmethod + def wrap_result(meth: Callable[[WithOps, Any], Expr]) -> Callable[[WithOps, WithOps], WithOps]: + @wraps(meth) + def meth_wrapper(self, *args, **kwargs) -> WithOps: + return WithOps(expr=meth(self, *args, **kwargs)) + return meth_wrapper + # -- Operator overloading --- - def __add__(self, other: WithOps | Number) -> WithOps: - if isinstance(other, Number): - other = WithOps(expr=PolyConst(value=other)) + @classmethod + def _ensure_is_withops(cls, obj: WithOps | Expr | Number) -> WithOps: + if isinstance(obj, WithOps): + return obj - return add(self, other) + if isinstance(obj, Expr): + return WithOps(expr=obj) - def __radd__(self, other: WithOps | Number) -> WithOps: - if isinstance(other, Number): - other = WithOps(expr=PolyConst(value=other)) + # mypy typechecker bug: https://github.com/python/mypy/issues/11673 + elif isinstance(obj, Number): # type: ignore[misc, arg-type] + return WithOps(expr=PolyConst(value=obj)) # type: ignore[call-arg] - return add(other, self) + raise TypeError(f"Cannot wrap {obj}!") # FIXME: Decent message - def __sub__(self, other: Self) -> Self: - raise NotImplementedError + @wrap_result + def __add__(self, other: WithOps | Number) -> Expr: + other = self._ensure_is_withops(other) + with self as left, other as right: + return MatAdd(left, right) - def __rsub__(self, other: Self) -> Self: - raise NotImplementedError + @wrap_result + def __radd__(self, num: Number) -> Expr: + other = self._ensure_is_withops(num) + with other as left, self as right: + return MatAdd(left, right) - def __neg__(self) -> Self: - raise NotImplementedError + @wrap_result + def __sub__(self, other: WithOps | Number) -> Expr: + other = self._ensure_is_withops(other) + with self as left, other as right: + return MatSub(left, right) - def __mul__(self, other: Any) -> Self: - raise NotImplementedError + @wrap_result + def __rsub__(self, num: Number) -> Expr: + other = self._ensure_is_withops(num) + with other as left, self as right: + return MatSub(left, right) - def __rmul__(self, other: Any) -> Self: + def __neg__(self) -> Self: raise NotImplementedError - def __pow__(self, other: Any) -> Self: - raise NotImplementedError + @wrap_result + def __mul__(self, other: WithOps | Number) -> Expr: + other = self._ensure_is_withops(other) + with self as left, other as right: + return MatElemMul(left, right) + + @wrap_result + def __rmul__(self, num: Number) -> Expr: + other = self._ensure_is_withops(num) + with other as left, self as right: + return MatElemMul(left, right) + + @wrap_result + def __pow__(self, num: Number) -> Expr: + other = self._ensure_is_withops(num) + with self as left, other as right: + # TODO: what if it is a matrix? Elementwise power? + return PolyExp(left, right) def __matmul__(self, other: Any) -> Self: raise NotImplementedError diff --git a/mdpoly/operations/exp.py b/mdpoly/operations/exp.py index 0f57fcb..2b17a07 100644 --- a/mdpoly/operations/exp.py +++ b/mdpoly/operations/exp.py @@ -8,8 +8,10 @@ from dataclasses import dataclass from ..abc import Expr, Const from ..errors import AlgebraicError +from ..index import Shape from . import BinaryOp, Reducible +from .mul import PolyMul # TODO: implement matrix exponential, use caley-hamilton thm magic @@ -22,12 +24,13 @@ class MatExp(BinaryOp, Reducible): @dataclass(eq=False) class PolyExp(BinaryOp, Reducible): """ Exponentiation operator between scalar polynomials. """ + shape: Shape = Shape.scalar() @property # type: ignore[override] def right(self) -> Const: # type: ignore[override] if not isinstance(super().right, Const): - raise AlgebraicError(f"Cannot raise {self.left} to {self.right} because" - f"{self.right} is not a constant.") + raise AlgebraicError(f"Cannot raise {str(self.left)} to {str(super().right)} " + f"because {str(super().right)} is not a constant.") return cast(Const, super().right) @@ -44,7 +47,7 @@ class PolyExp(BinaryOp, Reducible): raise NotImplementedError("Cannot raise to non-integer powers (yet).") ntimes = self.right.value - 1 - return reduce(opmul, (var for _ in range(ntimes)), var) + return reduce(PolyMul, (var for _ in range(ntimes)), var) def __str__(self) -> str: return f"({self.left} ** {self.right})" |