diff options
-rw-r--r-- | mdpoly/__init__.py | 30 | ||||
-rw-r--r-- | mdpoly/expressions.py | 15 | ||||
-rw-r--r-- | mdpoly/test/__main__.py | 2 | ||||
-rw-r--r-- | mdpoly/test/expressions.py | 32 |
4 files changed, 68 insertions, 11 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py index 88c6ced..afa3e27 100644 --- a/mdpoly/__init__.py +++ b/mdpoly/__init__.py @@ -113,7 +113,7 @@ from .expressions import (WithOps as _WithOps, from .state import State as _State -from typing import Self, Iterable +from typing import Self, Iterable, Type, TypeVar from dataclasses import dataclass @@ -129,6 +129,8 @@ State = _State # ┃╻┃┣┳┛┣━┫┣━┛┣━┛┣╸ ┣┳┛┗━┓ # ┗┻┛╹┗╸╹ ╹╹ ╹ ┗━╸╹┗╸┗━┛ +WithExtT = TypeVar("WithExtT", bound=_WithOps) + # FIXME: move out of this file class FromHelpers: @classmethod @@ -139,26 +141,40 @@ class FromHelpers: names = map(str.strip, names) yield from map(cls, names) + @classmethod + def from_extension(cls, name, ext: Type[WithExtT]) -> WithExtT: + return ext(expr=cls(name).expr) + @dataclass class Constant(_WithOps, FromHelpers): """ Constant values """ - def __init__(self, *args, **kwargs): - _WithOps.__init__(self, expr=_PolyConst(*args, **kwargs)) + def __init__(self, *args, expr=None, **kwargs): + # FIXME: make less ugly + if not expr: + _WithOps.__init__(self, expr=_PolyConst(*args, **kwargs)) + else: + _WithOps.__init__(self, expr=expr) @dataclass class Variable(_WithOps, FromHelpers): """ Polynomial Variable """ - def __init__(self, *args, **kwargs): - _WithOps.__init__(self, expr=_PolyVar(*args, **kwargs)) + def __init__(self, *args, expr=None, **kwargs): + if not expr: + _WithOps.__init__(self, expr=_PolyVar(*args, **kwargs)) + else: + _WithOps.__init__(self, expr=expr) @dataclass class Parameter(_WithOps, FromHelpers): """ Parameter that can be substituted """ - def __init__(self, *args, **kwargs): - _WithOps.__init__(self, expr=_PolyParam(*args, **kwargs)) + def __init__(self, *args, expr=None, **kwargs): + if not expr: + _WithOps.__init__(self, expr=_PolyParam(*args, **kwargs)) + else: + _WithOps.__init__(self, expr=expr) def __hash__(self): return hash((self.__class__.__qualname__, hash(self.unwrap()))) diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py index d7315b0..6d855a3 100644 --- a/mdpoly/expressions.py +++ b/mdpoly/expressions.py @@ -13,6 +13,7 @@ from .errors import MissingParameters, AlgebraicError from .operations.add import MatAdd, MatSub from .operations.mul import MatElemMul from .operations.exp import PolyExp +from .operations.derivative import PolyPartialDiff if TYPE_CHECKING: from .abc import ReprT @@ -149,6 +150,17 @@ class WithOps: # -- Magic methods --- def __enter__(self) -> Expr: + """ Allow nicer notation using the `with` statement. + + ... code:: py + f(x.expr) # replace this or + f(x.unwrap()) # this with + + with x as expr: + f(expr) + ... # do more stuff with expr + """ + # TODO: not very idiomatic, how much is this unidiomatic? return self.unwrap() def __exit__(self, *ex): @@ -159,6 +171,7 @@ class WithOps: def __getattr__(self, attr): # Behave transparently + # TODO: replace with selected attributes to expose? return getattr(self.unwrap(), attr) # -- Monadic operations -- @@ -207,7 +220,7 @@ class WithOps: def wrap_result(meth: Callable[[WithOps, Any], Expr]) -> Callable[[WithOps, WithOps], WithOps]: @wraps(meth) def meth_wrapper(self, *args, **kwargs) -> WithOps: - return WithOps(expr=meth(self, *args, **kwargs)) + return type(self)(expr=meth(self, *args, **kwargs)) return meth_wrapper # -- Operator overloading --- diff --git a/mdpoly/test/__main__.py b/mdpoly/test/__main__.py index 722065f..082862c 100644 --- a/mdpoly/test/__main__.py +++ b/mdpoly/test/__main__.py @@ -1,6 +1,6 @@ import unittest -from .expressions import TestPolyExpressions +from .expressions import TestPolyExpressions, TestExtension from .index import TestPolyVarIndex, TestPolyIndex unittest.main() diff --git a/mdpoly/test/expressions.py b/mdpoly/test/expressions.py index 1520f12..50ebd7f 100644 --- a/mdpoly/test/expressions.py +++ b/mdpoly/test/expressions.py @@ -8,8 +8,9 @@ from .. import Variable, Constant, Parameter, State from ..abc import Expr from ..errors import AlgebraicError -from ..expressions import PolyConst -from ..index import MatrixIndex +from ..expressions import PolyConst, WithOps +from ..operations import UnaryOp, Reducible +from ..index import Shape, MatrixIndex from ..representations import SparseRepr @@ -67,4 +68,31 @@ class TestPolyExpressions(TestCase): self.assertIsInstance(result.left, PolyConst) +# ┏━╸╻ ╻╺┳╸┏━╸┏┓╻┏━┓╻┏━┓┏┓╻┏━┓ +# ┣╸ ┏╋┛ ┃ ┣╸ ┃┗┫┗━┓┃┃ ┃┃┗┫┗━┓ +# ┗━╸╹ ╹ ╹ ┗━╸╹ ╹┗━┛╹┗━┛╹ ╹┗━┛ + +class Grok(UnaryOp, Reducible): + + @property + def shape(self) -> Shape: + return self.left.shape + + def reduce(self) -> Expr: + return self.left + + +class WithGrokOp(WithOps): + + @WithOps.wrap_result + def grok(self) -> Expr: + with self as inner: + return Grok(inner) + + +class TestExtension(TestCase): + + def test_grok_extension(self): + x = Variable.from_extension("x", WithGrokOp) + (x ** 2).grok() |