diff options
author | Nao Pross <np@0hm.ch> | 2024-03-03 01:10:58 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-03-03 01:10:58 +0100 |
commit | bfc3695fb426606c71a06b5a0a75571033889610 (patch) | |
tree | bd93a2e1d42db5b4b8721fc1875508b8e8b5ea78 | |
parent | Fix State.from_index crash when trying to get constant term (diff) | |
download | mdpoly-bfc3695fb426606c71a06b5a0a75571033889610.tar.gz mdpoly-bfc3695fb426606c71a06b5a0a75571033889610.zip |
Implement HasRepr for Param and fix repr of Const
Diffstat (limited to '')
-rw-r--r-- | mdpoly/errors.py | 5 | ||||
-rw-r--r-- | mdpoly/leaves.py | 16 |
2 files changed, 17 insertions, 4 deletions
diff --git a/mdpoly/errors.py b/mdpoly/errors.py index 1531a84..578c5c8 100644 --- a/mdpoly/errors.py +++ b/mdpoly/errors.py @@ -9,3 +9,8 @@ class InvalidShape(Exception): """ This is raised whenever there is an invalid shape """ + +class MissingParameters(Exception): + """ + This is raised whenever + """ diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py index 8a018e6..a9d2005 100644 --- a/mdpoly/leaves.py +++ b/mdpoly/leaves.py @@ -1,6 +1,7 @@ from .abc import Leaf, Repr from .types import Number, Shape, MatrixIndex, PolyVarIndex, PolyIndex from .state import State +from .errors import MissingParameters from .representations import HasRepr from typing import TypeVar @@ -33,10 +34,10 @@ class Const(Leaf, HasRepr): return r, state def __repr__(self) -> str: - if self.name: - return self.name + if not self.name: + return str(self.value) - return repr(self.value) + return self.name @dataclass(frozen=True) @@ -58,12 +59,19 @@ class Var(Leaf, HasRepr): @dataclass(frozen=True) -class Param(Leaf): +class Param(Leaf, HasRepr): """ A parameter """ name: str shape: Shape = Shape.scalar + def to_repr(self, repr_type: type[T], state: State) -> T: + if self not in state.parameters: + raise MissingParameters("Cannot construct representation because " + f"value for parameter {self} was not given.") + + return Const(state.parameters[self]).to_repr(repr_type, state) + def __repr__(self) -> str: return self.name |