diff options
-rw-r--r-- | mdpoly/errors.py | 5 | ||||
-rw-r--r-- | mdpoly/leaves.py | 16 |
2 files changed, 17 insertions, 4 deletions
diff --git a/mdpoly/errors.py b/mdpoly/errors.py index 1531a84..578c5c8 100644 --- a/mdpoly/errors.py +++ b/mdpoly/errors.py @@ -9,3 +9,8 @@ class InvalidShape(Exception): """ This is raised whenever there is an invalid shape """ + +class MissingParameters(Exception): + """ + This is raised whenever + """ diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py index 8a018e6..a9d2005 100644 --- a/mdpoly/leaves.py +++ b/mdpoly/leaves.py @@ -1,6 +1,7 @@ from .abc import Leaf, Repr from .types import Number, Shape, MatrixIndex, PolyVarIndex, PolyIndex from .state import State +from .errors import MissingParameters from .representations import HasRepr from typing import TypeVar @@ -33,10 +34,10 @@ class Const(Leaf, HasRepr): return r, state def __repr__(self) -> str: - if self.name: - return self.name + if not self.name: + return str(self.value) - return repr(self.value) + return self.name @dataclass(frozen=True) @@ -58,12 +59,19 @@ class Var(Leaf, HasRepr): @dataclass(frozen=True) -class Param(Leaf): +class Param(Leaf, HasRepr): """ A parameter """ name: str shape: Shape = Shape.scalar + def to_repr(self, repr_type: type[T], state: State) -> T: + if self not in state.parameters: + raise MissingParameters("Cannot construct representation because " + f"value for parameter {self} was not given.") + + return Const(state.parameters[self]).to_repr(repr_type, state) + def __repr__(self) -> str: return self.name |