aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/__init__.py30
-rw-r--r--mdpoly/expressions.py15
-rw-r--r--mdpoly/test/__main__.py2
-rw-r--r--mdpoly/test/expressions.py32
4 files changed, 68 insertions, 11 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py
index 88c6ced..afa3e27 100644
--- a/mdpoly/__init__.py
+++ b/mdpoly/__init__.py
@@ -113,7 +113,7 @@ from .expressions import (WithOps as _WithOps,
from .state import State as _State
-from typing import Self, Iterable
+from typing import Self, Iterable, Type, TypeVar
from dataclasses import dataclass
@@ -129,6 +129,8 @@ State = _State
# ┃╻┃┣┳┛┣━┫┣━┛┣━┛┣╸ ┣┳┛┗━┓
# ┗┻┛╹┗╸╹ ╹╹ ╹ ┗━╸╹┗╸┗━┛
+WithExtT = TypeVar("WithExtT", bound=_WithOps)
+
# FIXME: move out of this file
class FromHelpers:
@classmethod
@@ -139,26 +141,40 @@ class FromHelpers:
names = map(str.strip, names)
yield from map(cls, names)
+ @classmethod
+ def from_extension(cls, name, ext: Type[WithExtT]) -> WithExtT:
+ return ext(expr=cls(name).expr)
+
@dataclass
class Constant(_WithOps, FromHelpers):
""" Constant values """
- def __init__(self, *args, **kwargs):
- _WithOps.__init__(self, expr=_PolyConst(*args, **kwargs))
+ def __init__(self, *args, expr=None, **kwargs):
+ # FIXME: make less ugly
+ if not expr:
+ _WithOps.__init__(self, expr=_PolyConst(*args, **kwargs))
+ else:
+ _WithOps.__init__(self, expr=expr)
@dataclass
class Variable(_WithOps, FromHelpers):
""" Polynomial Variable """
- def __init__(self, *args, **kwargs):
- _WithOps.__init__(self, expr=_PolyVar(*args, **kwargs))
+ def __init__(self, *args, expr=None, **kwargs):
+ if not expr:
+ _WithOps.__init__(self, expr=_PolyVar(*args, **kwargs))
+ else:
+ _WithOps.__init__(self, expr=expr)
@dataclass
class Parameter(_WithOps, FromHelpers):
""" Parameter that can be substituted """
- def __init__(self, *args, **kwargs):
- _WithOps.__init__(self, expr=_PolyParam(*args, **kwargs))
+ def __init__(self, *args, expr=None, **kwargs):
+ if not expr:
+ _WithOps.__init__(self, expr=_PolyParam(*args, **kwargs))
+ else:
+ _WithOps.__init__(self, expr=expr)
def __hash__(self):
return hash((self.__class__.__qualname__, hash(self.unwrap())))
diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py
index d7315b0..6d855a3 100644
--- a/mdpoly/expressions.py
+++ b/mdpoly/expressions.py
@@ -13,6 +13,7 @@ from .errors import MissingParameters, AlgebraicError
from .operations.add import MatAdd, MatSub
from .operations.mul import MatElemMul
from .operations.exp import PolyExp
+from .operations.derivative import PolyPartialDiff
if TYPE_CHECKING:
from .abc import ReprT
@@ -149,6 +150,17 @@ class WithOps:
# -- Magic methods ---
def __enter__(self) -> Expr:
+ """ Allow nicer notation using the `with` statement.
+
+ ... code:: py
+ f(x.expr) # replace this or
+ f(x.unwrap()) # this with
+
+ with x as expr:
+ f(expr)
+ ... # do more stuff with expr
+ """
+ # TODO: not very idiomatic, how much is this unidiomatic?
return self.unwrap()
def __exit__(self, *ex):
@@ -159,6 +171,7 @@ class WithOps:
def __getattr__(self, attr):
# Behave transparently
+ # TODO: replace with selected attributes to expose?
return getattr(self.unwrap(), attr)
# -- Monadic operations --
@@ -207,7 +220,7 @@ class WithOps:
def wrap_result(meth: Callable[[WithOps, Any], Expr]) -> Callable[[WithOps, WithOps], WithOps]:
@wraps(meth)
def meth_wrapper(self, *args, **kwargs) -> WithOps:
- return WithOps(expr=meth(self, *args, **kwargs))
+ return type(self)(expr=meth(self, *args, **kwargs))
return meth_wrapper
# -- Operator overloading ---
diff --git a/mdpoly/test/__main__.py b/mdpoly/test/__main__.py
index 722065f..082862c 100644
--- a/mdpoly/test/__main__.py
+++ b/mdpoly/test/__main__.py
@@ -1,6 +1,6 @@
import unittest
-from .expressions import TestPolyExpressions
+from .expressions import TestPolyExpressions, TestExtension
from .index import TestPolyVarIndex, TestPolyIndex
unittest.main()
diff --git a/mdpoly/test/expressions.py b/mdpoly/test/expressions.py
index 1520f12..50ebd7f 100644
--- a/mdpoly/test/expressions.py
+++ b/mdpoly/test/expressions.py
@@ -8,8 +8,9 @@ from .. import Variable, Constant, Parameter, State
from ..abc import Expr
from ..errors import AlgebraicError
-from ..expressions import PolyConst
-from ..index import MatrixIndex
+from ..expressions import PolyConst, WithOps
+from ..operations import UnaryOp, Reducible
+from ..index import Shape, MatrixIndex
from ..representations import SparseRepr
@@ -67,4 +68,31 @@ class TestPolyExpressions(TestCase):
self.assertIsInstance(result.left, PolyConst)
+# ┏━╸╻ ╻╺┳╸┏━╸┏┓╻┏━┓╻┏━┓┏┓╻┏━┓
+# ┣╸ ┏╋┛ ┃ ┣╸ ┃┗┫┗━┓┃┃ ┃┃┗┫┗━┓
+# ┗━╸╹ ╹ ╹ ┗━╸╹ ╹┗━┛╹┗━┛╹ ╹┗━┛
+
+class Grok(UnaryOp, Reducible):
+
+ @property
+ def shape(self) -> Shape:
+ return self.left.shape
+
+ def reduce(self) -> Expr:
+ return self.left
+
+
+class WithGrokOp(WithOps):
+
+ @WithOps.wrap_result
+ def grok(self) -> Expr:
+ with self as inner:
+ return Grok(inner)
+
+
+class TestExtension(TestCase):
+
+ def test_grok_extension(self):
+ x = Variable.from_extension("x", WithGrokOp)
+ (x ** 2).grok()