diff options
Diffstat (limited to 'polymatrix/expression/init')
-rw-r--r-- | polymatrix/expression/init/initderivativeexpr.py | 5 | ||||
-rw-r--r-- | polymatrix/expression/init/initevalexpr.py | 182 | ||||
-rw-r--r-- | polymatrix/expression/init/initfromsympyexpr.py | 6 | ||||
-rw-r--r-- | polymatrix/expression/init/initlinearinexpr.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/init/initlinearmonomialsexpr.py | 7 | ||||
-rw-r--r-- | polymatrix/expression/init/initquadraticinexpr.py | 7 | ||||
-rw-r--r-- | polymatrix/expression/init/initquadraticmonomialsexpr.py | 7 | ||||
-rw-r--r-- | polymatrix/expression/init/initsubstituteexpr.py | 107 |
8 files changed, 261 insertions, 62 deletions
diff --git a/polymatrix/expression/init/initderivativeexpr.py b/polymatrix/expression/init/initderivativeexpr.py index c640f47..a6ca06c 100644 --- a/polymatrix/expression/init/initderivativeexpr.py +++ b/polymatrix/expression/init/initderivativeexpr.py @@ -4,9 +4,12 @@ from polymatrix.expression.impl.derivativeexprimpl import DerivativeExprImpl def init_derivative_expr( underlying: ExpressionBaseMixin, - variables: tuple, + variables: ExpressionBaseMixin, introduce_derivatives: bool = None, ): + + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + if introduce_derivatives is None: introduce_derivatives = False diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py index 5bbd404..0fad8e5 100644 --- a/polymatrix/expression/init/initevalexpr.py +++ b/polymatrix/expression/init/initevalexpr.py @@ -1,55 +1,171 @@ +import typing import numpy as np +from polymatrix.expression.init.initsubstituteexpr import format_substitutions from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.impl.evalexprimpl import EvalExprImpl def init_eval_expr( underlying: ExpressionBaseMixin, - variables: tuple, - values: tuple = None, + variables: typing.Union[typing.Any, tuple, dict], + values: typing.Union[float, tuple] = None, ): - if values is None: - if isinstance(variables, tuple): - variables, values = tuple(zip(*variables)) + substitutions = format_substitutions( + variables=variables, + values=values, + ) + + def formatted_values(value): + if isinstance(value, np.ndarray): + return tuple(value.reshape(-1)) + + elif isinstance(value, tuple): + return value - elif isinstance(variables, dict): - variables, values = tuple(zip(*variables.items())) + elif isinstance(value, int) or isinstance(value, float): + return (value,) else: - raise Exception(f'unsupported case {variables=}') + return (float(value),) + + substitutions = tuple((variable, formatted_values(value)) for variable, value in substitutions) + + return EvalExprImpl( + underlying=underlying, + substitutions=substitutions, + ) - elif isinstance(values, np.ndarray): - values = tuple(values.reshape(-1)) + # if values is not None: + # if isinstance(variables, tuple): + # if isinstance(values, tuple): + # assert len(variables) == len(values), f'{variables=}, {values=}' - elif not isinstance(values, tuple): - values = (values,) + # else: + # values = tuple(values for _ in variables) - if not isinstance(variables, tuple): - variables = (variables,) + # else: + # variables = (variables,) + # values = (values,) - def gen_formatted_values(): - for value in values: - if isinstance(value, np.ndarray): - yield from value.reshape(-1) + # subs = zip(variables, values) + + # elif isinstance(variables, dict): + # subs = variables.items() - elif isinstance(value, tuple): - yield from value + # elif isinstance(variables, tuple): + # subs = variables - elif isinstance(value, dict): - for variable in variables: - yield from value[variable] + # else: + # raise Exception(f'{variables=}') - elif isinstance(value, int) or isinstance(value, float): - yield value + # def formatted_values(value): + # if isinstance(value, np.ndarray): + # return tuple(value.reshape(-1)) - else: - yield float(value) + # elif isinstance(value, tuple): + # return value - values = tuple(gen_formatted_values()) + # elif isinstance(value, int) or isinstance(value, float): + # return (value,) - return EvalExprImpl( - underlying=underlying, - variables=variables, - values=values, -) + # else: + # return (float(value),) + + # subs = tuple((var, formatted_values(val)) for var, val in subs) + + # def formatted_values(value): + # # def gen_formatted_values(): + # # for value in values: + # if isinstance(value, np.ndarray): + # yield tuple(value.reshape(-1)) + + # elif isinstance(value, tuple): + # yield value + + # # elif isinstance(value, dict): + # # for variable in variables: + # # yield from value[variable] + + # elif isinstance(value, int) or isinstance(value, float): + # yield (value,) + + # else: + # yield (float(value),) + # return tuple(gen_formatted_values()) + + + # if values is None: + # if isinstance(variables, tuple): + # variables, values = tuple(zip(*variables)) + + # elif isinstance(variables, dict): + # variables, values = tuple(zip(*variables.items())) + + # else: + # raise Exception(f'unsupported case {variables=}') + + # elif isinstance(values, np.ndarray): + # values = tuple(values.reshape(-1)) + + # elif not isinstance(values, tuple): + # values = (values,) + + # if not isinstance(variables, tuple): + # variables = (variables,) + + # def gen_formatted_values(): + # for value in values: + # if isinstance(value, np.ndarray): + # yield tuple(value.reshape(-1)) + + # elif isinstance(value, tuple): + # yield value + + # elif isinstance(value, dict): + # raise Exception('is this right?') + + # for variable in variables: + # yield from value[variable] + + # elif isinstance(value, int) or isinstance(value, float): + # yield (value,) + + # else: + # yield (float(value),) + + # values = tuple(gen_formatted_values()) + + # if len(values) == 1: + # values = tuple((values[0],) for _ in variables) + + # else: + # assert len(variables) == len(values), f'length of {variables} does not match length of {values}' + + # def gen_flattened_values(): + # for value in values: + # if isinstance(value, np.ndarray): + # yield from value.reshape(-1) + + # elif isinstance(value, tuple): + # yield from value + + # elif isinstance(value, dict): + # raise Exception('is this right?') + + # for variable in variables: + # yield from value[variable] + + # elif isinstance(value, int) or isinstance(value, float): + # yield value + + # else: + # yield float(value) + + # values = tuple(gen_flattened_values()) + +# return EvalExprImpl( +# underlying=underlying, +# variables=variables, +# values=values, +# ) diff --git a/polymatrix/expression/init/initfromsympyexpr.py b/polymatrix/expression/init/initfromsympyexpr.py index bb37f1d..3fb52f7 100644 --- a/polymatrix/expression/init/initfromsympyexpr.py +++ b/polymatrix/expression/init/initfromsympyexpr.py @@ -36,7 +36,13 @@ def init_from_sympy_expr( case _: data = tuple((e,) for e in data) + case np.number: + data = ((float(data),),) + case _: + if not isinstance(data, (float, int, sympy.Expr)): + raise Exception(f'{data=}, {type(data)=}') + data = ((data,),) return FromSympyExprImpl( diff --git a/polymatrix/expression/init/initlinearinexpr.py b/polymatrix/expression/init/initlinearinexpr.py index 5cd172c..b869aee 100644 --- a/polymatrix/expression/init/initlinearinexpr.py +++ b/polymatrix/expression/init/initlinearinexpr.py @@ -8,6 +8,8 @@ def init_linear_in_expr( variables: ExpressionBaseMixin, ignore_unmatched: bool = None, ): + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + return LinearInExprImpl( underlying=underlying, monomials=monomials, diff --git a/polymatrix/expression/init/initlinearmonomialsexpr.py b/polymatrix/expression/init/initlinearmonomialsexpr.py index f116562..8083715 100644 --- a/polymatrix/expression/init/initlinearmonomialsexpr.py +++ b/polymatrix/expression/init/initlinearmonomialsexpr.py @@ -4,9 +4,12 @@ from polymatrix.expression.impl.linearmonomialsexprimpl import LinearMonomialsEx def init_linear_monomials_expr( underlying: ExpressionBaseMixin, - variables: tuple, + variables: ExpressionBaseMixin, ): + + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + return LinearMonomialsExprImpl( underlying=underlying, variables=variables, -) + ) diff --git a/polymatrix/expression/init/initquadraticinexpr.py b/polymatrix/expression/init/initquadraticinexpr.py index 5aa40a5..6555b4b 100644 --- a/polymatrix/expression/init/initquadraticinexpr.py +++ b/polymatrix/expression/init/initquadraticinexpr.py @@ -5,10 +5,13 @@ from polymatrix.expression.impl.quadraticinexprimpl import QuadraticInExprImpl def init_quadratic_in_expr( underlying: ExpressionBaseMixin, monomials: ExpressionBaseMixin, - variables: tuple, + variables: ExpressionBaseMixin, ): + + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + return QuadraticInExprImpl( underlying=underlying, monomials=monomials, variables=variables, -) + ) diff --git a/polymatrix/expression/init/initquadraticmonomialsexpr.py b/polymatrix/expression/init/initquadraticmonomialsexpr.py index 8e46c62..190f7df 100644 --- a/polymatrix/expression/init/initquadraticmonomialsexpr.py +++ b/polymatrix/expression/init/initquadraticmonomialsexpr.py @@ -4,9 +4,12 @@ from polymatrix.expression.impl.quadraticmonomialsexprimpl import QuadraticMonom def init_quadratic_monomials_expr( underlying: ExpressionBaseMixin, - variables: tuple, + variables: ExpressionBaseMixin, ): + + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + return QuadraticMonomialsExprImpl( underlying=underlying, variables=variables, -) + ) diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py index 403b169..50cbee0 100644 --- a/polymatrix/expression/init/initsubstituteexpr.py +++ b/polymatrix/expression/init/initsubstituteexpr.py @@ -1,3 +1,4 @@ +import typing import numpy as np from polymatrix.expression.init.initfromsympyexpr import init_from_sympy_expr @@ -5,35 +6,97 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.impl.substituteexprimpl import SubstituteExprImpl +def format_substitutions( + variables: typing.Union[typing.Any, tuple, dict], + values: typing.Union[float, tuple] = None, +): + """ + (variables = x, values = 1.0) # ok + (variables = x, values = np.array(1.0)) # ok + (variables = (x, y, z), values = 1.0) # ok + (variables = (x, y, z), values = (1.0, 2.0, 3.0)) # ok + (variables = {x: 1.0, y: 2.0, z: 3.0}) # ok + (variables = ((x, 1.0), (y, 2.0), (z, 3.0))) # ok + + (variables = v, values = (1.0, 2.0)) # ok + (variables = (v1, v2), values = ((1.0, 2.0), (3.0,))) # ok + (variables = (v1, v2), values = (1.0, 2.0, 3.0)) # not ok + """ + + if values is not None: + if isinstance(variables, tuple): + if isinstance(values, tuple): + assert len(variables) == len(values), f'{variables=}, {values=}' + + else: + values = tuple(values for _ in variables) + + else: + variables = (variables,) + values = (values,) + + substitutions = zip(variables, values) + + elif isinstance(variables, dict): + substitutions = variables.items() + + elif isinstance(variables, tuple): + substitutions = variables + + else: + raise Exception(f'{variables=}') + + return substitutions + + def init_substitute_expr( underlying: ExpressionBaseMixin, variables: tuple, - substitutions: tuple = None, + values: tuple = None, ): - if substitutions is None: - assert isinstance(variables, tuple) - - if len(variables) == 0: - return underlying - - variables, substitutions = tuple(zip(*variables)) - elif isinstance(substitutions, np.ndarray): - substitutions = tuple(substitutions.reshape(-1)) + substitutions = format_substitutions( + variables=variables, + values=values, + ) - elif not isinstance(substitutions, tuple): - substitutions = (substitutions,) + def formatted_values(value) -> ExpressionBaseMixin: + if isinstance(value, ExpressionBaseMixin): + return value + else: + return init_from_sympy_expr(value) - def gen_substitutions(): - for substitution in substitutions: - match substitution: - case ExpressionBaseMixin(): - yield substitution - case _: - yield init_from_sympy_expr(substitution) + substitutions = tuple((variable, formatted_values(value)) for variable, value in substitutions) return SubstituteExprImpl( underlying=underlying, - variables=variables, - substitutions=tuple(gen_substitutions()), -) + substitutions=substitutions, + ) + + # if values is None: + # assert isinstance(variables, tuple) + + # if len(variables) == 0: + # return underlying + + # variables, values = tuple(zip(*variables)) + + # elif isinstance(values, np.ndarray): + # values = tuple(values.reshape(-1)) + + # elif not isinstance(values, tuple): + # values = (values,) + + # def gen_substitutions(): + # for substitution in values: + # match substitution: + # case ExpressionBaseMixin(): + # yield substitution + # case _: + # yield init_from_sympy_expr(substitution) + + # return SubstituteExprImpl( + # underlying=underlying, + # variables=variables, + # substitutions=tuple(gen_substitutions()), + # ) |