aboutsummaryrefslogtreecommitdiffstats
path: root/mdpoly/state.py
blob: 092de84bfb620890abdeae9584a6a7cd26916943 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from __future__ import annotations
from typing import TYPE_CHECKING

from dataclasses import dataclass, field

from .index import PolyVarIndex
from .abc import Var, Param

if TYPE_CHECKING:
    from .index import Number


# Integer index that ``State`` assigns to each variable (see ``State.index``).
Index = int


@dataclass
class State:
    """ Bookkeeping of variable indices and parameter values.

    Each :class:`Var` is assigned a unique, increasing integer index the
    first time it is passed to :meth:`index`. The index ``-1`` is reserved:
    :meth:`from_index` maps it to the number ``1``.
    """

    # Variable -> index assigned to it by :meth:`index`.
    variables: dict[Var, Index] = field(default_factory=dict)
    # Parameter -> value; not read or written by the methods of this class.
    parameters: dict[Param, Number] = field(default_factory=dict)
    # Last index handed out; -1 means none yet (and -1 itself is reserved).
    _last_index: Index = -1

    def __post_init__(self) -> None:
        # When constructed with pre-populated ``variables``, make sure new
        # indices are allocated after the existing ones instead of colliding
        # with them (which would make ``from_index`` ambiguous).
        self._last_index = max(self._last_index,
                               max(self.variables.values(), default=-1))

    def _make_index(self) -> Index:
        """ Allocate and return the next unused index. """
        self._last_index += 1
        return self._last_index

    def index(self, var: Var) -> Index:
        """ Get the index for a variable.

        A new index is allocated the first time ``var`` is seen; subsequent
        calls with the same variable return the same index.

        :raises IndexError: if ``var`` is not a :class:`Var`.
        """
        # NOTE: TypeError would be more conventional, but IndexError is kept
        # for backward compatibility with existing callers.
        if not isinstance(var, Var):
            raise IndexError(f"Cannot index {var} (type {type(var)}). "
                    f"Only variables (type {Var}) can be indexed.")

        if var not in self.variables:
            self.variables[var] = self._make_index()

        return self.variables[var]

    def from_index(self, index: Index | PolyVarIndex) -> Var | Number:
        r""" Get a variable object from the index.

        This is a reverse lookup operation and it is **not**
        :math:`\mathcal{O}(1)`: it scans :attr:`variables` linearly.
        The reserved index ``-1`` maps to the number ``1``.

        :raises IndexError: if no variable has the given index.
        """
        if isinstance(index, PolyVarIndex):
            index = index.var_idx

        if index == -1:
            return 1

        for var, idx in self.variables.items():
            if idx == index:
                return var

        raise IndexError(f"There is no variable with index {index}.")