aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-18 21:06:14 +0100
committerNao Pross <np@0hm.ch>2024-03-18 21:06:14 +0100
commit61e004224c9915cd92c8a7bed4649107ed311152 (patch)
tree7fab17df2c9c299e3b6e59d838f2d9cbf529f255
parentMove operations (operator overloading etc) outside of expression class (diff)
downloadmdpoly-61e004224c9915cd92c8a7bed4649107ed311152.tar.gz
mdpoly-61e004224c9915cd92c8a7bed4649107ed311152.zip
Re-implement __add__, __sub__, __mul__
-rw-r--r--mdpoly/__init__.py4
-rw-r--r--mdpoly/expressions.py98
-rw-r--r--mdpoly/operations/exp.py9
3 files changed, 84 insertions, 27 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py
index 26025ab..af44d8a 100644
--- a/mdpoly/__init__.py
+++ b/mdpoly/__init__.py
@@ -114,6 +114,7 @@ from .expressions import (WithOps as _WithOps,
from .state import State as _State
from typing import Self, Iterable
+from dataclasses import dataclass
# ┏━╸╻ ╻┏━┓┏━┓┏━┓╺┳╸ ┏━┓┏━┓ ╻┏━┓
@@ -139,18 +140,21 @@ class FromHelpers:
yield from map(cls, names)
+@dataclass
class Constant(_WithOps, FromHelpers):
""" Constant values """
def __init__(self, *args, **kwargs):
_WithOps.__init__(self, expr=_PolyConst(*args, **kwargs))
+@dataclass
class Variable(_WithOps, FromHelpers):
""" Polynomial Variable """
def __init__(self, *args, **kwargs):
_WithOps.__init__(self, expr=_PolyVar(*args, **kwargs))
+@dataclass
class Parameter(_WithOps, FromHelpers):
""" Parameter that can be substituted """
def __init__(self, *args, **kwargs):
diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py
index c34a0aa..57f015b 100644
--- a/mdpoly/expressions.py
+++ b/mdpoly/expressions.py
@@ -7,12 +7,15 @@ from functools import wraps
from typing import Type, TypeVar, Iterable, Callable, Sequence, Any, Self, cast
from .abc import Expr, Var, Const, Param
-from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex
-from .errors import MissingParameters
+from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number
+from .errors import MissingParameters, AlgebraicError
+
+from .operations.add import MatAdd, MatSub
+from .operations.mul import MatElemMul
+from .operations.exp import PolyExp
if TYPE_CHECKING:
from .abc import ReprT
- from .index import Shape, Number
from .state import State
@@ -151,6 +154,17 @@ class WithOps:
overloading and operations to expression objects. """
expr: Expr
+ # -- Magic methods ---
+
+ def __enter__(self) -> Expr:
+ return self.expr
+
+ def __exit__(self, *ex):
+ return False # Propagate exceptions
+
+ def __str__(self):
+ return f"WithOps{str(self.expr)}"
+
# -- Monadic operations --
@staticmethod
@@ -179,8 +193,7 @@ class WithOps:
@staticmethod
def bind(fn: Callable[[Expr], WithOps]) -> Callable[[WithOps], WithOps]:
- """ Bind in the functional programming sense. Also sometimes known as
- flatmap.
+ """ Bind in the functional programming sense. Also known as flatmap.
Converts a funciton like `(Expr -> WithOps)` into `(WithOps -> WithOps)`
so that it can be composed with other functions.
@@ -190,37 +203,74 @@ class WithOps:
return fn(arg.expr)
return wrapper
+ @staticmethod
+ def wrap_result(meth: Callable[[WithOps, Any], Expr]) -> Callable[[WithOps, WithOps], WithOps]:
+ @wraps(meth)
+ def meth_wrapper(self, *args, **kwargs) -> WithOps:
+ return WithOps(expr=meth(self, *args, **kwargs))
+ return meth_wrapper
+
# -- Operator overloading ---
- def __add__(self, other: WithOps | Number) -> WithOps:
- if isinstance(other, Number):
- other = WithOps(expr=PolyConst(value=other))
+ @classmethod
+ def _ensure_is_withops(cls, obj: WithOps | Expr | Number) -> WithOps:
+ if isinstance(obj, WithOps):
+ return obj
- return add(self, other)
+ if isinstance(obj, Expr):
+ return WithOps(expr=obj)
- def __radd__(self, other: WithOps | Number) -> WithOps:
- if isinstance(other, Number):
- other = WithOps(expr=PolyConst(value=other))
+ # mypy typechecker bug: https://github.com/python/mypy/issues/11673
+ elif isinstance(obj, Number): # type: ignore[misc, arg-type]
+ return WithOps(expr=PolyConst(value=obj)) # type: ignore[call-arg]
- return add(other, self)
+ raise TypeError(f"Cannot wrap {obj}!") # FIXME: Decent message
- def __sub__(self, other: Self) -> Self:
- raise NotImplementedError
+ @wrap_result
+ def __add__(self, other: WithOps | Number) -> Expr:
+ other = self._ensure_is_withops(other)
+ with self as left, other as right:
+ return MatAdd(left, right)
- def __rsub__(self, other: Self) -> Self:
- raise NotImplementedError
+ @wrap_result
+ def __radd__(self, num: Number) -> Expr:
+ other = self._ensure_is_withops(num)
+ with other as left, self as right:
+ return MatAdd(left, right)
- def __neg__(self) -> Self:
- raise NotImplementedError
+ @wrap_result
+ def __sub__(self, other: WithOps | Number) -> Expr:
+ other = self._ensure_is_withops(other)
+ with self as left, other as right:
+ return MatSub(left, right)
- def __mul__(self, other: Any) -> Self:
- raise NotImplementedError
+ @wrap_result
+ def __rsub__(self, num: Number) -> Expr:
+ other = self._ensure_is_withops(num)
+ with other as left, self as right:
+ return MatSub(left, right)
- def __rmul__(self, other: Any) -> Self:
+ def __neg__(self) -> Self:
raise NotImplementedError
- def __pow__(self, other: Any) -> Self:
- raise NotImplementedError
+ @wrap_result
+ def __mul__(self, other: WithOps | Number) -> Expr:
+ other = self._ensure_is_withops(other)
+ with self as left, other as right:
+ return MatElemMul(left, right)
+
+ @wrap_result
+ def __rmul__(self, num: Number) -> Expr:
+ other = self._ensure_is_withops(num)
+ with other as left, self as right:
+ return MatElemMul(left, right)
+
+ @wrap_result
+ def __pow__(self, num: Number) -> Expr:
+ other = self._ensure_is_withops(num)
+ with self as left, other as right:
+ # TODO: what if it is a matrix? Elementwise power?
+ return PolyExp(left, right)
def __matmul__(self, other: Any) -> Self:
raise NotImplementedError
diff --git a/mdpoly/operations/exp.py b/mdpoly/operations/exp.py
index 0f57fcb..2b17a07 100644
--- a/mdpoly/operations/exp.py
+++ b/mdpoly/operations/exp.py
@@ -8,8 +8,10 @@ from dataclasses import dataclass
from ..abc import Expr, Const
from ..errors import AlgebraicError
+from ..index import Shape
from . import BinaryOp, Reducible
+from .mul import PolyMul
# TODO: implement matrix exponential, use caley-hamilton thm magic
@@ -22,12 +24,13 @@ class MatExp(BinaryOp, Reducible):
@dataclass(eq=False)
class PolyExp(BinaryOp, Reducible):
""" Exponentiation operator between scalar polynomials. """
+ shape: Shape = Shape.scalar()
@property # type: ignore[override]
def right(self) -> Const: # type: ignore[override]
if not isinstance(super().right, Const):
- raise AlgebraicError(f"Cannot raise {self.left} to {self.right} because"
- f"{self.right} is not a constant.")
+ raise AlgebraicError(f"Cannot raise {str(self.left)} to {str(super().right)} "
+ f"because {str(super().right)} is not a constant.")
return cast(Const, super().right)
@@ -44,7 +47,7 @@ class PolyExp(BinaryOp, Reducible):
raise NotImplementedError("Cannot raise to non-integer powers (yet).")
ntimes = self.right.value - 1
- return reduce(opmul, (var for _ in range(ntimes)), var)
+ return reduce(PolyMul, (var for _ in range(ntimes)), var)
def __str__(self) -> str:
return f"({self.left} ** {self.right})"