diff options
-rw-r--r-- | mdpoly/__init__.py | 3 | ||||
-rw-r--r-- | mdpoly/expressions.py | 17 | ||||
-rw-r--r-- | mdpoly/operations/mul.py | 1 | ||||
-rw-r--r-- | mdpoly/state.py | 66 |
4 files changed, 69 insertions, 18 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py index af44d8a..88c6ced 100644 --- a/mdpoly/__init__.py +++ b/mdpoly/__init__.py @@ -160,6 +160,9 @@ class Parameter(_WithOps, FromHelpers): def __init__(self, *args, **kwargs): _WithOps.__init__(self, expr=_PolyParam(*args, **kwargs)) + def __hash__(self): + return hash((self.__class__.__qualname__, hash(self.unwrap()))) + class MatrixConstant(_WithOps): """ Matrix constant """ diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py index 57f015b..dd033f0 100644 --- a/mdpoly/expressions.py +++ b/mdpoly/expressions.py @@ -136,11 +136,8 @@ class PolyParam(Param): shape: Shape = Shape.scalar() # overloads PolyExpr.shape def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - if self not in state.parameters: - raise MissingParameters("Cannot construct representation because " - f"value for parameter {self} was not given.") - - return PolyConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg] + p = state.parameter(self) + return PolyConst(p).to_repr(repr_type, state) # type: ignore[call-arg] # ┏━┓┏━┓┏━╸┏━┓┏━┓╺┳╸╻┏━┓┏┓╻┏━┓ @@ -157,7 +154,7 @@ class WithOps: # -- Magic methods --- def __enter__(self) -> Expr: - return self.expr + return self.unwrap() def __exit__(self, *ex): return False # Propagate exceptions @@ -165,8 +162,16 @@ class WithOps: def __str__(self): return f"WithOps{str(self.expr)}" + def __getattr__(self, attr): + # Behave transparently + return getattr(self.unwrap(), attr) + # -- Monadic operations -- + def unwrap(self) -> Expr: + """ Un-wrap value, get the raw expression without operations. """ + return self.expr + @staticmethod def map(fn: Callable[[Expr], Expr]) -> Callable[[WithOps], WithOps]: """ Map in the functional programming sense. diff --git a/mdpoly/operations/mul.py b/mdpoly/operations/mul.py index 7e75d26..1c5a629 100644 --- a/mdpoly/operations/mul.py +++ b/mdpoly/operations/mul.py @@ -157,6 +157,7 @@ class MatDotProd(BinaryOp, Reducible): @dataclass(eq=False) class PolyMul(BinaryOp): """ Multiplication operator between scalar polynomials. """ + shape: Shape = Shape.scalar() def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ diff --git a/mdpoly/state.py b/mdpoly/state.py index 092de84..2555dfc 100644 --- a/mdpoly/state.py +++ b/mdpoly/state.py @@ -1,40 +1,67 @@ from __future__ import annotations from typing import TYPE_CHECKING -from dataclasses import dataclass, field - +from .abc import Expr, Var, Param from .index import PolyVarIndex -from .abc import Var, Param +from .expressions import WithOps +from .errors import MissingParameters if TYPE_CHECKING: from .index import Number + from .expressions import WithOps Index = int -@dataclass class State: - variables: dict[Var, Index] = field(default_factory=dict) - parameters: dict[Param, Number] = field(default_factory=dict) - _last_index: Index = -1 + def __init__(self, variables: dict[Var | WithOps, Index] = {}, + parameters: dict[Param | WithOps, Number] = {}): + + self._variables: dict[Var, Index] = dict() + self._parameters: dict[Param, Number] = dict() + self._last_index: Index = -1 + + for var, idx in variables.items(): + v = var.unwrap() if isinstance(var, WithOps) else var + + if not isinstance(v, Var): + raise IndexError(f"Cannot index {v} (type {type(v)}). " + f"Only variables (type {Var}) can be indexed.") + + self._variables[v] = idx + self._last_index = max(self._last_index, idx) + + # FIXME: allow matrix constants + for param, val in parameters.items(): + p = param.unwrap() if isinstance(param, WithOps) else param + + if not isinstance(p, Param): + raise IndexError(f"Cannot get parameter {p} (type {type(p)}). " + f"Because its type is not {Param}.") + + self._parameters[p] = val + def _make_index(self) -> Index: self._last_index += 1 return self._last_index - def index(self, var: Var) -> Index: + def index(self, var: Var | WithOps) -> Index: """ Get the index for a variable. """ + if isinstance(var, WithOps): + var: Expr | Var = var.unwrap() # type: ignore[no-redef] + if not isinstance(var, Var): raise IndexError(f"Cannot index {var} (type {type(var)}). " f"Only variables (type {Var}) can be indexed.") - if var not in self.variables.keys(): + if var not in self._variables.keys(): new_index = self._make_index() - self.variables[var] = new_index + self._variables[var] = new_index return new_index - return self.variables[var] + return self._variables[var] def from_index(self, index: Index | PolyVarIndex) -> Var | Number: """ Get a variable object from the index. @@ -46,8 +73,23 @@ class State: if index == -1: return 1 - for var, idx in self.variables.items(): + for var, idx in self._variables.items(): if idx == index: return var raise IndexError(f"There is no variable with index {index}.") + + def parameter(self, param: Param | WithOps) -> Number: + """ Get the value for a parameter. """ + if isinstance(param, WithOps): + param: Param | Expr = param.unwrap() # type: ignore[no-redef] + + if not isinstance(param, Param): + raise IndexError(f"Cannot get parameter {param} (type {type(param)}). " + f"Because its type is not {Param}.") + + if param not in self._parameters: + raise MissingParameters("Cannot construct representation because " + f"value for parameter {param} was not given.") + + return self._parameters[param] |