aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-20 11:23:49 +0100
committerNao Pross <np@0hm.ch>2024-03-20 11:24:53 +0100
commit4c765763c41e904a32cf68fa7a43fa48b04fb040 (patch)
tree488935c8077ab9f9d772a1658e6029005f0cd842
parentRe-implement __matmul__ (diff)
downloadmdpoly-4c765763c41e904a32cf68fa7a43fa48b04fb040.tar.gz
mdpoly-4c765763c41e904a32cf68fa7a43fa48b04fb040.zip
Rename WithOps to Expression
-rw-r--r--mdpoly/__init__.py28
-rw-r--r--mdpoly/expressions.py60
-rw-r--r--mdpoly/operations/__init__.py2
-rw-r--r--mdpoly/state.py20
-rw-r--r--mdpoly/test/expressions.py6
5 files changed, 58 insertions, 58 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py
index 2fbdb88..d6a9a63 100644
--- a/mdpoly/__init__.py
+++ b/mdpoly/__init__.py
@@ -107,7 +107,7 @@ TODO: data structure(s) that represent the polynomial
from .index import (Shape as _Shape)
-from .expressions import (WithOps as _WithOps)
+from .expressions import (Expression as _Expression)
from .leaves import (PolyConst as _PolyConst, PolyVar as _PolyVar, PolyParam as _PolyParam,
MatConst as _MatConst, MatVar as _MatVar, MatParam as _MatParam)
@@ -129,7 +129,7 @@ State = _State
# ┃╻┃┣┳┛┣━┫┣━┛┣━┛┣╸ ┣┳┛┗━┓
# ┗┻┛╹┗╸╹ ╹╹ ╹ ┗━╸╹┗╸┗━┛
-WithExtT = TypeVar("WithExtT", bound=_WithOps)
+WithExtT = TypeVar("WithExtT", bound=_Expression)
# FIXME: move out of this file
class FromHelpers:
@@ -147,48 +147,48 @@ class FromHelpers:
@dataclass
-class Constant(_WithOps, FromHelpers):
+class Constant(_Expression, FromHelpers):
""" Constant values """
def __init__(self, *args, expr=None, **kwargs):
# FIXME: make less ugly
if not expr:
- _WithOps.__init__(self, expr=_PolyConst(*args, **kwargs))
+ _Expression.__init__(self, expr=_PolyConst(*args, **kwargs))
else:
- _WithOps.__init__(self, expr=expr)
+ _Expression.__init__(self, expr=expr)
@dataclass
-class Variable(_WithOps, FromHelpers):
+class Variable(_Expression, FromHelpers):
""" Polynomial Variable """
def __init__(self, *args, expr=None, **kwargs):
if not expr:
- _WithOps.__init__(self, expr=_PolyVar(*args, **kwargs))
+ _Expression.__init__(self, expr=_PolyVar(*args, **kwargs))
else:
- _WithOps.__init__(self, expr=expr)
+ _Expression.__init__(self, expr=expr)
@dataclass
-class Parameter(_WithOps, FromHelpers):
+class Parameter(_Expression, FromHelpers):
""" Parameter that can be substituted """
def __init__(self, *args, expr=None, **kwargs):
if not expr:
- _WithOps.__init__(self, expr=_PolyParam(*args, **kwargs))
+ _Expression.__init__(self, expr=_PolyParam(*args, **kwargs))
else:
- _WithOps.__init__(self, expr=expr)
+ _Expression.__init__(self, expr=expr)
def __hash__(self):
return hash((self.__class__.__qualname__, hash(self.unwrap())))
-class MatrixConstant(_WithOps):
+class MatrixConstant(_Expression):
""" Matrix constant """
-class MatrixVariable(_WithOps):
+class MatrixVariable(_Expression):
""" Matrix Polynomial Variable """
-class MatrixParameter(_WithOps):
+class MatrixParameter(_Expression):
""" Matrix Parameter """
diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py
index edaf1aa..88baedf 100644
--- a/mdpoly/expressions.py
+++ b/mdpoly/expressions.py
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
@dataclass
-class WithOps:
+class Expression:
""" Monadic wrapper around :py:class:`mdpoly.abc.Expr` that adds operator
overloading and operations to expression objects. """
expr: Expr
@@ -52,12 +52,12 @@ class WithOps:
def __exit__(self, exc_type: type, exc_val: Exception, tb):
""" Context manager exit.
- See :py:meth:`mdpoly.expresssions.WithOps.__enter__`.
+ See :py:meth:`mdpoly.expresssions.Expression.__enter__`.
"""
return False # Propagate exceptions
def __str__(self):
- return f"WithOps{str(self.expr)}"
+ return f"Expression{str(self.expr)}"
def __getattr__(self, attr):
# Behave transparently
@@ -71,53 +71,53 @@ class WithOps:
return self.expr
@staticmethod
- def map(fn: Callable[[Expr], Expr]) -> Callable[[WithOps], WithOps]:
+ def map(fn: Callable[[Expr], Expr]) -> Callable[[Expression], Expression]:
""" Map in the functional programming sense.
Converts a function like `(Expr -> Expr)` into
- `(WithOps<Expr> -> WithOps<Expr>)`.
+ `(Expression<Expr> -> Expression<Expr>)`.
"""
@wraps(fn)
- def wrapper(e: WithOps) -> WithOps:
- return WithOps(expr=fn(e.expr))
+ def wrapper(e: Expression) -> Expression:
+ return Expression(expr=fn(e.expr))
return wrapper
@staticmethod
- def zip(fn: Callable[..., Expr]) -> Callable[..., WithOps]:
+ def zip(fn: Callable[..., Expr]) -> Callable[..., Expression]:
""" Zip, in the functional programming sense.
Converts a function like `(Expr -> Expr -> ... -> Expr)` into
- `(WithOps<Expr> -> WithOps<Expr> -> ... -> WithOps<Expr>)`.
+ `(Expression<Expr> -> Expression<Expr> -> ... -> Expression<Expr>)`.
"""
@wraps(fn)
- def wrapper(*args: WithOps) -> WithOps:
- return WithOps(expr=fn(*(arg.expr for arg in args)))
+ def wrapper(*args: Expression) -> Expression:
+ return Expression(expr=fn(*(arg.expr for arg in args)))
return wrapper
@staticmethod
- def bind(fn: Callable[[Expr], WithOps]) -> Callable[[WithOps], WithOps]:
+ def bind(fn: Callable[[Expr], Expression]) -> Callable[[Expression], Expression]:
""" Bind in the functional programming sense. Also known as flatmap.
- Converts a funciton like `(Expr -> WithOps)` into `(WithOps -> WithOps)`
+ Converts a funciton like `(Expr -> Expression)` into `(Expression -> Expression)`
so that it can be composed with other functions.
"""
@wraps(fn)
- def wrapper(arg: WithOps) -> WithOps:
+ def wrapper(arg: Expression) -> Expression:
return fn(arg.expr)
return wrapper
@staticmethod
- def wrap_result(meth: Callable[[WithOps, Any], Expr]) -> Callable[[WithOps, Any], WithOps]:
+ def wrap_result(meth: Callable[[Expression, Any], Expr]) -> Callable[[Expression, Any], Expression]:
""" Take a method and wrap its result type.
- Turns method `(WithOps, Any) -> Expr)` into
- `(WithOps, Any) -> WithOps)`. method arguments are left unchanged.
+ Turns method `(Expression, Any) -> Expr)` into
+ `(Expression, Any) -> Expression)`. method arguments are left unchanged.
This is only for conveniente to avoid always having to wrap the result
by hand.
"""
@wraps(meth)
- def meth_wrapper(self, *args, **kwargs) -> WithOps:
- # Why type(self)? Because if we are wrapping a subtype of WithOps,
+ def meth_wrapper(self, *args, **kwargs) -> Expression:
+ # Why type(self)? Because if we are wrapping a subtype of Expression,
# eg. from an extension we want to preserve its type. See for
# example mdpoly.test.TestExtensions.
return type(self)(expr=meth(self, *args, **kwargs))
@@ -126,22 +126,22 @@ class WithOps:
# -- Operator overloading ---
@classmethod
- def _ensure_is_withops(cls, obj: WithOps | Expr | Number) -> WithOps:
- """ Ensures that the given object is wrapped with type WithOps. """
- if isinstance(obj, WithOps):
+ def _ensure_is_withops(cls, obj: Expression | Expr | Number) -> Expression:
+ """ Ensures that the given object is wrapped with type Expression. """
+ if isinstance(obj, Expression):
return obj
if isinstance(obj, Expr):
- return WithOps(expr=obj)
+ return Expression(expr=obj)
# mypy typechecker bug: https://github.com/python/mypy/issues/11673
elif isinstance(obj, Number): # type: ignore[misc, arg-type]
- return WithOps(expr=PolyConst(value=obj)) # type: ignore[call-arg]
+ return Expression(expr=PolyConst(value=obj)) # type: ignore[call-arg]
raise TypeError(f"Cannot wrap {obj}!") # FIXME: Decent message
@wrap_result
- def __add__(self, other: WithOps | Number) -> Expr:
+ def __add__(self, other: Expression | Number) -> Expr:
other = self._ensure_is_withops(other)
with self as left, other as right:
return MatAdd(left, right)
@@ -153,7 +153,7 @@ class WithOps:
return MatAdd(left, right)
@wrap_result
- def __sub__(self, other: WithOps | Number) -> Expr:
+ def __sub__(self, other: Expression | Number) -> Expr:
other = self._ensure_is_withops(other)
with self as left, other as right:
return MatSub(left, right)
@@ -168,7 +168,7 @@ class WithOps:
raise NotImplementedError
@wrap_result
- def __mul__(self, other: WithOps | Number) -> Expr:
+ def __mul__(self, other: Expression | Number) -> Expr:
other = self._ensure_is_withops(other)
with self as left, other as right:
return MatElemMul(left, right)
@@ -187,7 +187,7 @@ class WithOps:
return PolyExp(left, right)
@wrap_result
- def __matmul__(self, other: WithOps) -> Expr:
+ def __matmul__(self, other: Expression) -> Expr:
other = self._ensure_is_withops(other)
with self as left, other as right:
return MatMul(left, right)
@@ -209,7 +209,7 @@ class WithOps:
# --- More operations ---
@wrap_result
- def diff(self, with_respect_to: WithOps) -> Expr:
+ def diff(self, with_respect_to: Expression) -> Expr:
with self as expr, with_respect_to as wrt:
if not isinstance(wrt, Var):
raise AlgebraicError(
@@ -221,7 +221,7 @@ class WithOps:
return PolyPartialDiff(expr, wrt=wrt)
@wrap_result
- def integrate(self, with_respect_to: WithOps) -> Expr:
+ def integrate(self, with_respect_to: Expression) -> Expr:
raise NotImplementedError
# -- Make representations ---
diff --git a/mdpoly/operations/__init__.py b/mdpoly/operations/__init__.py
index cd95f45..59c6ba6 100644
--- a/mdpoly/operations/__init__.py
+++ b/mdpoly/operations/__init__.py
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
class WithTraceback:
""" Dataclass that contains traceback information.
- See where it is used in :py:meth:`mdpoly.expressions.WithOps.to_repr`.
+ See where it is used in :py:meth:`mdpoly.expressions.Expresion.to_repr`.
"""
def __post_init__(self):
# NOTE: leaving this information around indefinitely can cause memory
diff --git a/mdpoly/state.py b/mdpoly/state.py
index 2555dfc..16c9b32 100644
--- a/mdpoly/state.py
+++ b/mdpoly/state.py
@@ -3,27 +3,27 @@ from typing import TYPE_CHECKING
from .abc import Expr, Var, Param
from .index import PolyVarIndex
-from .expressions import WithOps
+from .expressions import Expression
from .errors import MissingParameters
if TYPE_CHECKING:
from .index import Number
- from .expressions import WithOps
+ from .expressions import Expression
Index = int
class State:
- def __init__(self, variables: dict[Var | WithOps, Index] = {},
- parameters: dict[Param | WithOps, Number] = {}):
+ def __init__(self, variables: dict[Var | Expression, Index] = {},
+ parameters: dict[Param | Expression, Number] = {}):
self._variables: dict[Var, Index] = dict()
self._parameters: dict[Param, Number] = dict()
self._last_index: Index = -1
for var, idx in variables.items():
- v = var.unwrap() if isinstance(var, WithOps) else var
+ v = var.unwrap() if isinstance(var, Expression) else var
if not isinstance(v, Var):
raise IndexError(f"Cannot index {v} (type {type(v)}). "
@@ -34,7 +34,7 @@ class State:
# FIXME: allow matrix constants
for param, val in parameters.items():
- p = param.unwrap() if isinstance(param, WithOps) else param
+ p = param.unwrap() if isinstance(param, Expression) else param
if not isinstance(p, Param):
raise IndexError(f"Cannot get parameter {p} (type {type(p)}). "
@@ -47,9 +47,9 @@ class State:
self._last_index += 1
return self._last_index
- def index(self, var: Var | WithOps) -> Index:
+ def index(self, var: Var | Expression) -> Index:
""" Get the index for a variable. """
- if isinstance(var, WithOps):
+ if isinstance(var, Expression):
var: Expr | Var = var.unwrap() # type: ignore[no-redef]
if not isinstance(var, Var):
@@ -79,9 +79,9 @@ class State:
raise IndexError(f"There is no variable with index {index}.")
- def parameter(self, param: Param | WithOps) -> Number:
+ def parameter(self, param: Param | Expression) -> Number:
""" Get the value for a parameter. """
- if isinstance(param, WithOps):
+ if isinstance(param, Expression):
param: Param | Expr = param.unwrap() # type: ignore[no-redef]
if not isinstance(param, Param):
diff --git a/mdpoly/test/expressions.py b/mdpoly/test/expressions.py
index 50ebd7f..07422d8 100644
--- a/mdpoly/test/expressions.py
+++ b/mdpoly/test/expressions.py
@@ -8,7 +8,7 @@ from .. import Variable, Constant, Parameter, State
from ..abc import Expr
from ..errors import AlgebraicError
-from ..expressions import PolyConst, WithOps
+from ..expressions import PolyConst, Expression
from ..operations import UnaryOp, Reducible
from ..index import Shape, MatrixIndex
from ..representations import SparseRepr
@@ -82,9 +82,9 @@ class Grok(UnaryOp, Reducible):
return self.left
-class WithGrokOp(WithOps):
+class WithGrokOp(Expression):
- @WithOps.wrap_result
+ @Expression.wrap_result
def grok(self) -> Expr:
with self as inner:
return Grok(inner)