diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-08-09 10:19:31 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-08-09 10:19:31 +0200 |
commit | 1e9ce38579d52aa2e0c8a2b87ecebb59ef4427fc (patch) | |
tree | f145d2a05ac223e1ab9caa4b2210a15e9b4d8635 | |
parent | add trace operator (diff) | |
download | polymatrix-1e9ce38579d52aa2e0c8a2b87ecebb59ef4427fc.tar.gz polymatrix-1e9ce38579d52aa2e0c8a2b87ecebb59ef4427fc.zip |
allow dict in eval expression
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/init/initevalexpr.py | 35 |
1 files changed, 30 insertions, 5 deletions
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py index 49bb0a3..dad9664 100644 --- a/polymatrix/expression/init/initevalexpr.py +++ b/polymatrix/expression/init/initevalexpr.py @@ -10,9 +10,14 @@ def init_eval_expr( ): if values is None: - assert isinstance(variables, tuple) + if isinstance(variables, tuple): + variables, values = tuple(zip(*variables)) - variables, values = tuple(zip(*variables)) + elif isinstance(variables, dict): + variables, values = tuple(zip(*variables.items())) + + else: + raise Exception(f'unsupported case {variables=}') elif isinstance(values, np.ndarray): values = tuple(values.reshape(-1)) @@ -20,11 +25,31 @@ def init_eval_expr( elif not isinstance(values, tuple): values = (values,) - # if not isinstance(variables, tuple): - # variables = (variables,) + if not isinstance(variables, tuple): + variables = (variables,) + + def gen_formatted_values(): + for value in values: + if isinstance(value, np.ndarray): + yield from value.reshape(-1) + + elif isinstance(value, tuple): + yield from value + + elif isinstance(value, dict): + for variable in variables: + yield from value[variable] + + elif isinstance(value, int) or isinstance(value, float): + # else: + yield value + + else: + yield float(value) + # raise Exception(f'{value=}, {type(value)=}') return EvalExprImpl( underlying=underlying, variables=variables, - values=values, + values=tuple(gen_formatted_values()), ) |