summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/init/initevalexpr.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/init/initevalexpr.py')
-rw-r--r--polymatrix/expression/init/initevalexpr.py182
1 files changed, 149 insertions, 33 deletions
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py
index 5bbd404..0fad8e5 100644
--- a/polymatrix/expression/init/initevalexpr.py
+++ b/polymatrix/expression/init/initevalexpr.py
@@ -1,55 +1,171 @@
+import typing
import numpy as np
+from polymatrix.expression.init.initsubstituteexpr import format_substitutions
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.impl.evalexprimpl import EvalExprImpl
def init_eval_expr(
underlying: ExpressionBaseMixin,
- variables: tuple,
- values: tuple = None,
+ variables: typing.Union[typing.Any, tuple, dict],
+ values: typing.Union[float, tuple] = None,
):
- if values is None:
- if isinstance(variables, tuple):
- variables, values = tuple(zip(*variables))
+ substitutions = format_substitutions(
+ variables=variables,
+ values=values,
+ )
+
+ def formatted_values(value):
+ if isinstance(value, np.ndarray):
+ return tuple(value.reshape(-1))
+
+ elif isinstance(value, tuple):
+ return value
- elif isinstance(variables, dict):
- variables, values = tuple(zip(*variables.items()))
+ elif isinstance(value, int) or isinstance(value, float):
+ return (value,)
else:
- raise Exception(f'unsupported case {variables=}')
+ return (float(value),)
+
+ substitutions = tuple((variable, formatted_values(value)) for variable, value in substitutions)
+
+ return EvalExprImpl(
+ underlying=underlying,
+ substitutions=substitutions,
+ )
- elif isinstance(values, np.ndarray):
- values = tuple(values.reshape(-1))
+ # if values is not None:
+ # if isinstance(variables, tuple):
+ # if isinstance(values, tuple):
+ # assert len(variables) == len(values), f'{variables=}, {values=}'
- elif not isinstance(values, tuple):
- values = (values,)
+ # else:
+ # values = tuple(values for _ in variables)
- if not isinstance(variables, tuple):
- variables = (variables,)
+ # else:
+ # variables = (variables,)
+ # values = (values,)
- def gen_formatted_values():
- for value in values:
- if isinstance(value, np.ndarray):
- yield from value.reshape(-1)
+ # subs = zip(variables, values)
+
+ # elif isinstance(variables, dict):
+ # subs = variables.items()
- elif isinstance(value, tuple):
- yield from value
+ # elif isinstance(variables, tuple):
+ # subs = variables
- elif isinstance(value, dict):
- for variable in variables:
- yield from value[variable]
+ # else:
+ # raise Exception(f'{variables=}')
- elif isinstance(value, int) or isinstance(value, float):
- yield value
+ # def formatted_values(value):
+ # if isinstance(value, np.ndarray):
+ # return tuple(value.reshape(-1))
- else:
- yield float(value)
+ # elif isinstance(value, tuple):
+ # return value
- values = tuple(gen_formatted_values())
+ # elif isinstance(value, int) or isinstance(value, float):
+ # return (value,)
- return EvalExprImpl(
- underlying=underlying,
- variables=variables,
- values=values,
-)
+ # else:
+ # return (float(value),)
+
+ # subs = tuple((var, formatted_values(val)) for var, val in subs)
+
+ # def formatted_values(value):
+ # # def gen_formatted_values():
+ # # for value in values:
+ # if isinstance(value, np.ndarray):
+ # yield tuple(value.reshape(-1))
+
+ # elif isinstance(value, tuple):
+ # yield value
+
+ # # elif isinstance(value, dict):
+ # # for variable in variables:
+ # # yield from value[variable]
+
+ # elif isinstance(value, int) or isinstance(value, float):
+ # yield (value,)
+
+ # else:
+ # yield (float(value),)
+ # return tuple(gen_formatted_values())
+
+
+ # if values is None:
+ # if isinstance(variables, tuple):
+ # variables, values = tuple(zip(*variables))
+
+ # elif isinstance(variables, dict):
+ # variables, values = tuple(zip(*variables.items()))
+
+ # else:
+ # raise Exception(f'unsupported case {variables=}')
+
+ # elif isinstance(values, np.ndarray):
+ # values = tuple(values.reshape(-1))
+
+ # elif not isinstance(values, tuple):
+ # values = (values,)
+
+ # if not isinstance(variables, tuple):
+ # variables = (variables,)
+
+ # def gen_formatted_values():
+ # for value in values:
+ # if isinstance(value, np.ndarray):
+ # yield tuple(value.reshape(-1))
+
+ # elif isinstance(value, tuple):
+ # yield value
+
+ # elif isinstance(value, dict):
+ # raise Exception('is this right?')
+
+ # for variable in variables:
+ # yield from value[variable]
+
+ # elif isinstance(value, int) or isinstance(value, float):
+ # yield (value,)
+
+ # else:
+ # yield (float(value),)
+
+ # values = tuple(gen_formatted_values())
+
+ # if len(values) == 1:
+ # values = tuple((values[0],) for _ in variables)
+
+ # else:
+ # assert len(variables) == len(values), f'length of {variables} does not match length of {values}'
+
+ # def gen_flattened_values():
+ # for value in values:
+ # if isinstance(value, np.ndarray):
+ # yield from value.reshape(-1)
+
+ # elif isinstance(value, tuple):
+ # yield from value
+
+ # elif isinstance(value, dict):
+ # raise Exception('is this right?')
+
+ # for variable in variables:
+ # yield from value[variable]
+
+ # elif isinstance(value, int) or isinstance(value, float):
+ # yield value
+
+ # else:
+ # yield float(value)
+
+ # values = tuple(gen_flattened_values())
+
+# return EvalExprImpl(
+# underlying=underlying,
+# variables=variables,
+# values=values,
+# )