aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/__init__.py3
-rw-r--r--mdpoly/expressions.py17
-rw-r--r--mdpoly/operations/mul.py1
-rw-r--r--mdpoly/state.py66
4 files changed, 69 insertions, 18 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py
index af44d8a..88c6ced 100644
--- a/mdpoly/__init__.py
+++ b/mdpoly/__init__.py
@@ -160,6 +160,9 @@ class Parameter(_WithOps, FromHelpers):
def __init__(self, *args, **kwargs):
_WithOps.__init__(self, expr=_PolyParam(*args, **kwargs))
+ def __hash__(self):
+ return hash((self.__class__.__qualname__, hash(self.unwrap())))
+
class MatrixConstant(_WithOps):
""" Matrix constant """
diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py
index 57f015b..dd033f0 100644
--- a/mdpoly/expressions.py
+++ b/mdpoly/expressions.py
@@ -136,11 +136,8 @@ class PolyParam(Param):
shape: Shape = Shape.scalar() # overloads PolyExpr.shape
def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- if self not in state.parameters:
- raise MissingParameters("Cannot construct representation because "
- f"value for parameter {self} was not given.")
-
- return PolyConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg]
+ p = state.parameter(self)
+ return PolyConst(p).to_repr(repr_type, state) # type: ignore[call-arg]
# ┏━┓┏━┓┏━╸┏━┓┏━┓╺┳╸╻┏━┓┏┓╻┏━┓
@@ -157,7 +154,7 @@ class WithOps:
# -- Magic methods ---
def __enter__(self) -> Expr:
- return self.expr
+ return self.unwrap()
def __exit__(self, *ex):
return False # Propagate exceptions
@@ -165,8 +162,16 @@ class WithOps:
def __str__(self):
return f"WithOps{str(self.expr)}"
+ def __getattr__(self, attr):
+ # Behave transparently
+ return getattr(self.unwrap(), attr)
+
# -- Monadic operations --
+ def unwrap(self) -> Expr:
+ """ Un-wrap value, get the raw expression without operations. """
+ return self.expr
+
@staticmethod
def map(fn: Callable[[Expr], Expr]) -> Callable[[WithOps], WithOps]:
""" Map in the functional programming sense.
diff --git a/mdpoly/operations/mul.py b/mdpoly/operations/mul.py
index 7e75d26..1c5a629 100644
--- a/mdpoly/operations/mul.py
+++ b/mdpoly/operations/mul.py
@@ -157,6 +157,7 @@ class MatDotProd(BinaryOp, Reducible):
@dataclass(eq=False)
class PolyMul(BinaryOp):
""" Multiplication operator between scalar polynomials. """
+ shape: Shape = Shape.scalar()
def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
""" See :py:meth:`mdpoly.abc.Expr.to_repr`. """
diff --git a/mdpoly/state.py b/mdpoly/state.py
index 092de84..2555dfc 100644
--- a/mdpoly/state.py
+++ b/mdpoly/state.py
@@ -1,40 +1,67 @@
from __future__ import annotations
from typing import TYPE_CHECKING
-from dataclasses import dataclass, field
-
+from .abc import Expr, Var, Param
from .index import PolyVarIndex
-from .abc import Var, Param
+from .expressions import WithOps
+from .errors import MissingParameters
if TYPE_CHECKING:
from .index import Number
+ from .expressions import WithOps
Index = int
-@dataclass
class State:
- variables: dict[Var, Index] = field(default_factory=dict)
- parameters: dict[Param, Number] = field(default_factory=dict)
- _last_index: Index = -1
+ def __init__(self, variables: dict[Var | WithOps, Index] = {},
+ parameters: dict[Param | WithOps, Number] = {}):
+
+ self._variables: dict[Var, Index] = dict()
+ self._parameters: dict[Param, Number] = dict()
+ self._last_index: Index = -1
+
+ for var, idx in variables.items():
+ v = var.unwrap() if isinstance(var, WithOps) else var
+
+ if not isinstance(v, Var):
+ raise IndexError(f"Cannot index {v} (type {type(v)}). "
+ f"Only variables (type {Var}) can be indexed.")
+
+ self._variables[v] = idx
+ self._last_index = max(self._last_index, idx)
+
+ # FIXME: allow matrix constants
+ for param, val in parameters.items():
+ p = param.unwrap() if isinstance(param, WithOps) else param
+
+ if not isinstance(p, Param):
+ raise IndexError(f"Cannot get parameter {p} (type {type(p)}). "
+ f"Because its type is not {Param}.")
+
+ self._parameters[p] = val
+
def _make_index(self) -> Index:
self._last_index += 1
return self._last_index
- def index(self, var: Var) -> Index:
+ def index(self, var: Var | WithOps) -> Index:
""" Get the index for a variable. """
+ if isinstance(var, WithOps):
+ var: Expr | Var = var.unwrap() # type: ignore[no-redef]
+
if not isinstance(var, Var):
raise IndexError(f"Cannot index {var} (type {type(var)}). "
f"Only variables (type {Var}) can be indexed.")
- if var not in self.variables.keys():
+ if var not in self._variables.keys():
new_index = self._make_index()
- self.variables[var] = new_index
+ self._variables[var] = new_index
return new_index
- return self.variables[var]
+ return self._variables[var]
def from_index(self, index: Index | PolyVarIndex) -> Var | Number:
""" Get a variable object from the index.
@@ -46,8 +73,23 @@ class State:
if index == -1:
return 1
- for var, idx in self.variables.items():
+ for var, idx in self._variables.items():
if idx == index:
return var
raise IndexError(f"There is no variable with index {index}.")
+
+ def parameter(self, param: Param | WithOps) -> Number:
+ """ Get the value for a parameter. """
+ if isinstance(param, WithOps):
+ param: Param | Expr = param.unwrap() # type: ignore[no-redef]
+
+ if not isinstance(param, Param):
+ raise IndexError(f"Cannot get parameter {param} (type {type(param)}). "
+ f"Because its type is not {Param}.")
+
+ if param not in self._parameters:
+ raise MissingParameters("Cannot construct representation because "
+ f"value for parameter {param} was not given.")
+
+ return self._parameters[param]