diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-05-09 09:46:20 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-05-09 09:46:20 +0200 |
commit | c7a6424aff9a904b3452b42cba8c5422c147292c (patch) | |
tree | be0bbed545c6faa2fea641eaa76f1b33512123da | |
parent | reshape according to the number of rows of expressions (diff) | |
download | polymatrix-c7a6424aff9a904b3452b42cba8c5422c147292c.tar.gz polymatrix-c7a6424aff9a904b3452b42cba8c5422c147292c.zip |
save offset of a parameter with its name
-rw-r--r-- | polymatrix/expression/init/initevalexpr.py | 134 | ||||
-rw-r--r-- | polymatrix/expression/init/initsubstituteexpr.py | 28 | ||||
-rw-r--r-- | polymatrix/expression/mixins/parametrizeexprmixin.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/utils/getvariableindices.py | 5 |
4 files changed, 2 insertions, 167 deletions
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py index c59a8d5..29359d6 100644 --- a/polymatrix/expression/init/initevalexpr.py +++ b/polymatrix/expression/init/initevalexpr.py @@ -35,137 +35,3 @@ def init_eval_expr( underlying=underlying, substitutions=substitutions, ) - - # if values is not None: - # if isinstance(variables, tuple): - # if isinstance(values, tuple): - # assert len(variables) == len(values), f'{variables=}, {values=}' - - # else: - # values = tuple(values for _ in variables) - - # else: - # variables = (variables,) - # values = (values,) - - # subs = zip(variables, values) - - # elif isinstance(variables, dict): - # subs = variables.items() - - # elif isinstance(variables, tuple): - # subs = variables - - # else: - # raise Exception(f'{variables=}') - - # def formatted_values(value): - # if isinstance(value, np.ndarray): - # return tuple(value.reshape(-1)) - - # elif isinstance(value, tuple): - # return value - - # elif isinstance(value, int) or isinstance(value, float): - # return (value,) - - # else: - # return (float(value),) - - # subs = tuple((var, formatted_values(val)) for var, val in subs) - - # def formatted_values(value): - # # def gen_formatted_values(): - # # for value in values: - # if isinstance(value, np.ndarray): - # yield tuple(value.reshape(-1)) - - # elif isinstance(value, tuple): - # yield value - - # # elif isinstance(value, dict): - # # for variable in variables: - # # yield from value[variable] - - # elif isinstance(value, int) or isinstance(value, float): - # yield (value,) - - # else: - # yield (float(value),) - # return tuple(gen_formatted_values()) - - - # if values is None: - # if isinstance(variables, tuple): - # variables, values = tuple(zip(*variables)) - - # elif isinstance(variables, dict): - # variables, values = tuple(zip(*variables.items())) - - # else: - # raise Exception(f'unsupported case {variables=}') - - # elif isinstance(values, np.ndarray): - # values = tuple(values.reshape(-1)) - - # elif not isinstance(values, tuple): - # values = (values,) - - # if not isinstance(variables, tuple): - # variables = (variables,) - - # def gen_formatted_values(): - # for value in values: - # if isinstance(value, np.ndarray): - # yield tuple(value.reshape(-1)) - - # elif isinstance(value, tuple): - # yield value - - # elif isinstance(value, dict): - # raise Exception('is this right?') - - # for variable in variables: - # yield from value[variable] - - # elif isinstance(value, int) or isinstance(value, float): - # yield (value,) - - # else: - # yield (float(value),) - - # values = tuple(gen_formatted_values()) - - # if len(values) == 1: - # values = tuple((values[0],) for _ in variables) - - # else: - # assert len(variables) == len(values), f'length of {variables} does not match length of {values}' - - # def gen_flattened_values(): - # for value in values: - # if isinstance(value, np.ndarray): - # yield from value.reshape(-1) - - # elif isinstance(value, tuple): - # yield from value - - # elif isinstance(value, dict): - # raise Exception('is this right?') - - # for variable in variables: - # yield from value[variable] - - # elif isinstance(value, int) or isinstance(value, float): - # yield value - - # else: - # yield float(value) - - # values = tuple(gen_flattened_values()) - -# return EvalExprImpl( -# underlying=underlying, -# variables=variables, -# values=values, -# ) diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py index 15a6566..e141d9e 100644 --- a/polymatrix/expression/init/initsubstituteexpr.py +++ b/polymatrix/expression/init/initsubstituteexpr.py @@ -73,31 +73,3 @@ def init_substitute_expr( underlying=underlying, substitutions=substitutions, ) - - # if values is None: - # assert isinstance(variables, tuple) - - # if len(variables) == 0: - # return underlying - - # variables, values = tuple(zip(*variables)) - - # elif isinstance(values, np.ndarray): - # values = tuple(values.reshape(-1)) - - # elif not isinstance(values, tuple): - # values = (values,) - - # def gen_substitutions(): - # for substitution in values: - # match substitution: - # case ExpressionBaseMixin(): - # yield substitution - # case _: - # yield init_from_sympy_expr(substitution) - - # return SubstituteExprImpl( - # underlying=underlying, - # variables=variables, - # substitutions=tuple(gen_substitutions()), - # ) diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py index c33ca5b..bb85387 100644 --- a/polymatrix/expression/mixins/parametrizeexprmixin.py +++ b/polymatrix/expression/mixins/parametrizeexprmixin.py @@ -40,7 +40,7 @@ class ParametrizeExprMixin(ExpressionBaseMixin): terms[row, 0] = {((var_index, 1),): 1.0} state = state.register( - key=self, + key=self.name, n_param=underlying.shape[0], ) diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index 61bae2a..4ee7037 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -29,16 +29,13 @@ def get_variable_indices_from_variable(state, variable) -> tuple[int] | None: variable_indices = tuple(gen_variables_indices()) elif isinstance(variable, int): - # raise Exception(f'{variable=}') variable_indices = (variable,) elif variable in state.offset_dict: - # raise Exception(f'{variable=}') - variable_indices = (state.offset_dict[variable][0],) + variable_indices = tuple(range(*state.offset_dict[variable])) else: variable_indices = None - # raise Exception(f'variable index not found for {variable=}, {state.offset_dict=}') return state, variable_indices |