diff options
Diffstat (limited to 'polymatrix/expression/init/initevalexpr.py')
-rw-r--r-- | polymatrix/expression/init/initevalexpr.py | 182 |
1 files changed, 149 insertions, 33 deletions
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py index 5bbd404..0fad8e5 100644 --- a/polymatrix/expression/init/initevalexpr.py +++ b/polymatrix/expression/init/initevalexpr.py @@ -1,55 +1,171 @@ +import typing import numpy as np +from polymatrix.expression.init.initsubstituteexpr import format_substitutions from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.impl.evalexprimpl import EvalExprImpl def init_eval_expr( underlying: ExpressionBaseMixin, - variables: tuple, - values: tuple = None, + variables: typing.Union[typing.Any, tuple, dict], + values: typing.Union[float, tuple] = None, ): - if values is None: - if isinstance(variables, tuple): - variables, values = tuple(zip(*variables)) + substitutions = format_substitutions( + variables=variables, + values=values, + ) + + def formatted_values(value): + if isinstance(value, np.ndarray): + return tuple(value.reshape(-1)) + + elif isinstance(value, tuple): + return value - elif isinstance(variables, dict): - variables, values = tuple(zip(*variables.items())) + elif isinstance(value, int) or isinstance(value, float): + return (value,) else: - raise Exception(f'unsupported case {variables=}') + return (float(value),) + + substitutions = tuple((variable, formatted_values(value)) for variable, value in substitutions) + + return EvalExprImpl( + underlying=underlying, + substitutions=substitutions, + ) - elif isinstance(values, np.ndarray): - values = tuple(values.reshape(-1)) + # if values is not None: + # if isinstance(variables, tuple): + # if isinstance(values, tuple): + # assert len(variables) == len(values), f'{variables=}, {values=}' - elif not isinstance(values, tuple): - values = (values,) + # else: + # values = tuple(values for _ in variables) - if not isinstance(variables, tuple): - variables = (variables,) + # else: + # variables = (variables,) + # values = (values,) - def gen_formatted_values(): - for value in values: - if isinstance(value, np.ndarray): - yield from value.reshape(-1) + # subs = zip(variables, values) + + # elif isinstance(variables, dict): + # subs = variables.items() - elif isinstance(value, tuple): - yield from value + # elif isinstance(variables, tuple): + # subs = variables - elif isinstance(value, dict): - for variable in variables: - yield from value[variable] + # else: + # raise Exception(f'{variables=}') - elif isinstance(value, int) or isinstance(value, float): - yield value + # def formatted_values(value): + # if isinstance(value, np.ndarray): + # return tuple(value.reshape(-1)) - else: - yield float(value) + # elif isinstance(value, tuple): + # return value - values = tuple(gen_formatted_values()) + # elif isinstance(value, int) or isinstance(value, float): + # return (value,) - return EvalExprImpl( - underlying=underlying, - variables=variables, - values=values, -) + # else: + # return (float(value),) + + # subs = tuple((var, formatted_values(val)) for var, val in subs) + + # def formatted_values(value): + # # def gen_formatted_values(): + # # for value in values: + # if isinstance(value, np.ndarray): + # yield tuple(value.reshape(-1)) + + # elif isinstance(value, tuple): + # yield value + + # # elif isinstance(value, dict): + # # for variable in variables: + # # yield from value[variable] + + # elif isinstance(value, int) or isinstance(value, float): + # yield (value,) + + # else: + # yield (float(value),) + # return tuple(gen_formatted_values()) + + + # if values is None: + # if isinstance(variables, tuple): + # variables, values = tuple(zip(*variables)) + + # elif isinstance(variables, dict): + # variables, values = tuple(zip(*variables.items())) + + # else: + # raise Exception(f'unsupported case {variables=}') + + # elif isinstance(values, np.ndarray): + # values = tuple(values.reshape(-1)) + + # elif not isinstance(values, tuple): + # values = (values,) + + # if not isinstance(variables, tuple): + # variables = (variables,) + + # def gen_formatted_values(): + # for value in values: + # if isinstance(value, np.ndarray): + # yield tuple(value.reshape(-1)) + + # elif isinstance(value, tuple): + # yield value + + # elif isinstance(value, dict): + # raise Exception('is this right?') + + # for variable in variables: + # yield from value[variable] + + # elif isinstance(value, int) or isinstance(value, float): + # yield (value,) + + # else: + # yield (float(value),) + + # values = tuple(gen_formatted_values()) + + # if len(values) == 1: + # values = tuple((values[0],) for _ in variables) + + # else: + # assert len(variables) == len(values), f'length of {variables} does not match length of {values}' + + # def gen_flattened_values(): + # for value in values: + # if isinstance(value, np.ndarray): + # yield from value.reshape(-1) + + # elif isinstance(value, tuple): + # yield from value + + # elif isinstance(value, dict): + # raise Exception('is this right?') + + # for variable in variables: + # yield from value[variable] + + # elif isinstance(value, int) or isinstance(value, float): + # yield value + + # else: + # yield float(value) + + # values = tuple(gen_flattened_values()) + +# return EvalExprImpl( +# underlying=underlying, +# variables=variables, +# values=values, +# ) |