summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2023-05-09 09:46:20 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2023-05-09 09:46:20 +0200
commitc7a6424aff9a904b3452b42cba8c5422c147292c (patch)
treebe0bbed545c6faa2fea641eaa76f1b33512123da
parentreshape according to the number of rows of expressions (diff)
downloadpolymatrix-c7a6424aff9a904b3452b42cba8c5422c147292c.tar.gz
polymatrix-c7a6424aff9a904b3452b42cba8c5422c147292c.zip
save offset of a parameter with its name
-rw-r--r--polymatrix/expression/init/initevalexpr.py134
-rw-r--r--polymatrix/expression/init/initsubstituteexpr.py28
-rw-r--r--polymatrix/expression/mixins/parametrizeexprmixin.py2
-rw-r--r--polymatrix/expression/utils/getvariableindices.py5
4 files changed, 2 insertions, 167 deletions
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py
index c59a8d5..29359d6 100644
--- a/polymatrix/expression/init/initevalexpr.py
+++ b/polymatrix/expression/init/initevalexpr.py
@@ -35,137 +35,3 @@ def init_eval_expr(
underlying=underlying,
substitutions=substitutions,
)
-
- # if values is not None:
- # if isinstance(variables, tuple):
- # if isinstance(values, tuple):
- # assert len(variables) == len(values), f'{variables=}, {values=}'
-
- # else:
- # values = tuple(values for _ in variables)
-
- # else:
- # variables = (variables,)
- # values = (values,)
-
- # subs = zip(variables, values)
-
- # elif isinstance(variables, dict):
- # subs = variables.items()
-
- # elif isinstance(variables, tuple):
- # subs = variables
-
- # else:
- # raise Exception(f'{variables=}')
-
- # def formatted_values(value):
- # if isinstance(value, np.ndarray):
- # return tuple(value.reshape(-1))
-
- # elif isinstance(value, tuple):
- # return value
-
- # elif isinstance(value, int) or isinstance(value, float):
- # return (value,)
-
- # else:
- # return (float(value),)
-
- # subs = tuple((var, formatted_values(val)) for var, val in subs)
-
- # def formatted_values(value):
- # # def gen_formatted_values():
- # # for value in values:
- # if isinstance(value, np.ndarray):
- # yield tuple(value.reshape(-1))
-
- # elif isinstance(value, tuple):
- # yield value
-
- # # elif isinstance(value, dict):
- # # for variable in variables:
- # # yield from value[variable]
-
- # elif isinstance(value, int) or isinstance(value, float):
- # yield (value,)
-
- # else:
- # yield (float(value),)
- # return tuple(gen_formatted_values())
-
-
- # if values is None:
- # if isinstance(variables, tuple):
- # variables, values = tuple(zip(*variables))
-
- # elif isinstance(variables, dict):
- # variables, values = tuple(zip(*variables.items()))
-
- # else:
- # raise Exception(f'unsupported case {variables=}')
-
- # elif isinstance(values, np.ndarray):
- # values = tuple(values.reshape(-1))
-
- # elif not isinstance(values, tuple):
- # values = (values,)
-
- # if not isinstance(variables, tuple):
- # variables = (variables,)
-
- # def gen_formatted_values():
- # for value in values:
- # if isinstance(value, np.ndarray):
- # yield tuple(value.reshape(-1))
-
- # elif isinstance(value, tuple):
- # yield value
-
- # elif isinstance(value, dict):
- # raise Exception('is this right?')
-
- # for variable in variables:
- # yield from value[variable]
-
- # elif isinstance(value, int) or isinstance(value, float):
- # yield (value,)
-
- # else:
- # yield (float(value),)
-
- # values = tuple(gen_formatted_values())
-
- # if len(values) == 1:
- # values = tuple((values[0],) for _ in variables)
-
- # else:
- # assert len(variables) == len(values), f'length of {variables} does not match length of {values}'
-
- # def gen_flattened_values():
- # for value in values:
- # if isinstance(value, np.ndarray):
- # yield from value.reshape(-1)
-
- # elif isinstance(value, tuple):
- # yield from value
-
- # elif isinstance(value, dict):
- # raise Exception('is this right?')
-
- # for variable in variables:
- # yield from value[variable]
-
- # elif isinstance(value, int) or isinstance(value, float):
- # yield value
-
- # else:
- # yield float(value)
-
- # values = tuple(gen_flattened_values())
-
-# return EvalExprImpl(
-# underlying=underlying,
-# variables=variables,
-# values=values,
-# )
diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py
index 15a6566..e141d9e 100644
--- a/polymatrix/expression/init/initsubstituteexpr.py
+++ b/polymatrix/expression/init/initsubstituteexpr.py
@@ -73,31 +73,3 @@ def init_substitute_expr(
underlying=underlying,
substitutions=substitutions,
)
-
- # if values is None:
- # assert isinstance(variables, tuple)
-
- # if len(variables) == 0:
- # return underlying
-
- # variables, values = tuple(zip(*variables))
-
- # elif isinstance(values, np.ndarray):
- # values = tuple(values.reshape(-1))
-
- # elif not isinstance(values, tuple):
- # values = (values,)
-
- # def gen_substitutions():
- # for substitution in values:
- # match substitution:
- # case ExpressionBaseMixin():
- # yield substitution
- # case _:
- # yield init_from_sympy_expr(substitution)
-
- # return SubstituteExprImpl(
- # underlying=underlying,
- # variables=variables,
- # substitutions=tuple(gen_substitutions()),
- # )
diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py
index c33ca5b..bb85387 100644
--- a/polymatrix/expression/mixins/parametrizeexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizeexprmixin.py
@@ -40,7 +40,7 @@ class ParametrizeExprMixin(ExpressionBaseMixin):
terms[row, 0] = {((var_index, 1),): 1.0}
state = state.register(
- key=self,
+ key=self.name,
n_param=underlying.shape[0],
)
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index 61bae2a..4ee7037 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -29,16 +29,13 @@ def get_variable_indices_from_variable(state, variable) -> tuple[int] | None:
variable_indices = tuple(gen_variables_indices())
elif isinstance(variable, int):
- # raise Exception(f'{variable=}')
variable_indices = (variable,)
elif variable in state.offset_dict:
- # raise Exception(f'{variable=}')
- variable_indices = (state.offset_dict[variable][0],)
+ variable_indices = tuple(range(*state.offset_dict[variable]))
else:
variable_indices = None
- # raise Exception(f'variable index not found for {variable=}, {state.offset_dict=}')
return state, variable_indices