summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-08-09 10:19:31 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-08-09 10:19:31 +0200
commit1e9ce38579d52aa2e0c8a2b87ecebb59ef4427fc (patch)
treef145d2a05ac223e1ab9caa4b2210a15e9b4d8635
parentadd trace operator (diff)
downloadpolymatrix-1e9ce38579d52aa2e0c8a2b87ecebb59ef4427fc.tar.gz
polymatrix-1e9ce38579d52aa2e0c8a2b87ecebb59ef4427fc.zip
allow dict in eval expression
-rw-r--r--polymatrix/expression/init/initevalexpr.py35
1 files changed, 30 insertions, 5 deletions
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py
index 49bb0a3..dad9664 100644
--- a/polymatrix/expression/init/initevalexpr.py
+++ b/polymatrix/expression/init/initevalexpr.py
@@ -10,9 +10,14 @@ def init_eval_expr(
):
if values is None:
- assert isinstance(variables, tuple)
+ if isinstance(variables, tuple):
+ variables, values = tuple(zip(*variables))
- variables, values = tuple(zip(*variables))
+ elif isinstance(variables, dict):
+ variables, values = tuple(zip(*variables.items()))
+
+ else:
+ raise Exception(f'unsupported case {variables=}')
elif isinstance(values, np.ndarray):
values = tuple(values.reshape(-1))
@@ -20,11 +25,31 @@ def init_eval_expr(
elif not isinstance(values, tuple):
values = (values,)
- # if not isinstance(variables, tuple):
- # variables = (variables,)
+ if not isinstance(variables, tuple):
+ variables = (variables,)
+
+ def gen_formatted_values():
+ for value in values:
+ if isinstance(value, np.ndarray):
+ yield from value.reshape(-1)
+
+ elif isinstance(value, tuple):
+ yield from value
+
+ elif isinstance(value, dict):
+ for variable in variables:
+ yield from value[variable]
+
+ elif isinstance(value, int) or isinstance(value, float):
+ # else:
+ yield value
+
+ else:
+ yield float(value)
+ # raise Exception(f'{value=}, {type(value)=}')
return EvalExprImpl(
underlying=underlying,
variables=variables,
- values=values,
+ values=tuple(gen_formatted_values()),
)