summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/init
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/init')
-rw-r--r--polymatrix/expression/init/initderivativeexpr.py5
-rw-r--r--polymatrix/expression/init/initevalexpr.py182
-rw-r--r--polymatrix/expression/init/initfromsympyexpr.py6
-rw-r--r--polymatrix/expression/init/initlinearinexpr.py2
-rw-r--r--polymatrix/expression/init/initlinearmonomialsexpr.py7
-rw-r--r--polymatrix/expression/init/initquadraticinexpr.py7
-rw-r--r--polymatrix/expression/init/initquadraticmonomialsexpr.py7
-rw-r--r--polymatrix/expression/init/initsubstituteexpr.py107
8 files changed, 261 insertions, 62 deletions
diff --git a/polymatrix/expression/init/initderivativeexpr.py b/polymatrix/expression/init/initderivativeexpr.py
index c640f47..a6ca06c 100644
--- a/polymatrix/expression/init/initderivativeexpr.py
+++ b/polymatrix/expression/init/initderivativeexpr.py
@@ -4,9 +4,12 @@ from polymatrix.expression.impl.derivativeexprimpl import DerivativeExprImpl
def init_derivative_expr(
underlying: ExpressionBaseMixin,
- variables: tuple,
+ variables: ExpressionBaseMixin,
introduce_derivatives: bool = None,
):
+
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+
if introduce_derivatives is None:
introduce_derivatives = False
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py
index 5bbd404..0fad8e5 100644
--- a/polymatrix/expression/init/initevalexpr.py
+++ b/polymatrix/expression/init/initevalexpr.py
@@ -1,55 +1,171 @@
+import typing
import numpy as np
+from polymatrix.expression.init.initsubstituteexpr import format_substitutions
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.impl.evalexprimpl import EvalExprImpl
def init_eval_expr(
underlying: ExpressionBaseMixin,
- variables: tuple,
- values: tuple = None,
+ variables: typing.Union[typing.Any, tuple, dict],
+ values: typing.Union[float, tuple] = None,
):
- if values is None:
- if isinstance(variables, tuple):
- variables, values = tuple(zip(*variables))
+ substitutions = format_substitutions(
+ variables=variables,
+ values=values,
+ )
+
+ def formatted_values(value):
+ if isinstance(value, np.ndarray):
+ return tuple(value.reshape(-1))
+
+ elif isinstance(value, tuple):
+ return value
- elif isinstance(variables, dict):
- variables, values = tuple(zip(*variables.items()))
+ elif isinstance(value, int) or isinstance(value, float):
+ return (value,)
else:
- raise Exception(f'unsupported case {variables=}')
+ return (float(value),)
+
+ substitutions = tuple((variable, formatted_values(value)) for variable, value in substitutions)
+
+ return EvalExprImpl(
+ underlying=underlying,
+ substitutions=substitutions,
+ )
- elif isinstance(values, np.ndarray):
- values = tuple(values.reshape(-1))
+ # if values is not None:
+ # if isinstance(variables, tuple):
+ # if isinstance(values, tuple):
+ # assert len(variables) == len(values), f'{variables=}, {values=}'
- elif not isinstance(values, tuple):
- values = (values,)
+ # else:
+ # values = tuple(values for _ in variables)
- if not isinstance(variables, tuple):
- variables = (variables,)
+ # else:
+ # variables = (variables,)
+ # values = (values,)
- def gen_formatted_values():
- for value in values:
- if isinstance(value, np.ndarray):
- yield from value.reshape(-1)
+ # subs = zip(variables, values)
+
+ # elif isinstance(variables, dict):
+ # subs = variables.items()
- elif isinstance(value, tuple):
- yield from value
+ # elif isinstance(variables, tuple):
+ # subs = variables
- elif isinstance(value, dict):
- for variable in variables:
- yield from value[variable]
+ # else:
+ # raise Exception(f'{variables=}')
- elif isinstance(value, int) or isinstance(value, float):
- yield value
+ # def formatted_values(value):
+ # if isinstance(value, np.ndarray):
+ # return tuple(value.reshape(-1))
- else:
- yield float(value)
+ # elif isinstance(value, tuple):
+ # return value
- values = tuple(gen_formatted_values())
+ # elif isinstance(value, int) or isinstance(value, float):
+ # return (value,)
- return EvalExprImpl(
- underlying=underlying,
- variables=variables,
- values=values,
-)
+ # else:
+ # return (float(value),)
+
+ # subs = tuple((var, formatted_values(val)) for var, val in subs)
+
+ # def formatted_values(value):
+ # # def gen_formatted_values():
+ # # for value in values:
+ # if isinstance(value, np.ndarray):
+ # yield tuple(value.reshape(-1))
+
+ # elif isinstance(value, tuple):
+ # yield value
+
+ # # elif isinstance(value, dict):
+ # # for variable in variables:
+ # # yield from value[variable]
+
+ # elif isinstance(value, int) or isinstance(value, float):
+ # yield (value,)
+
+ # else:
+ # yield (float(value),)
+ # return tuple(gen_formatted_values())
+
+
+ # if values is None:
+ # if isinstance(variables, tuple):
+ # variables, values = tuple(zip(*variables))
+
+ # elif isinstance(variables, dict):
+ # variables, values = tuple(zip(*variables.items()))
+
+ # else:
+ # raise Exception(f'unsupported case {variables=}')
+
+ # elif isinstance(values, np.ndarray):
+ # values = tuple(values.reshape(-1))
+
+ # elif not isinstance(values, tuple):
+ # values = (values,)
+
+ # if not isinstance(variables, tuple):
+ # variables = (variables,)
+
+ # def gen_formatted_values():
+ # for value in values:
+ # if isinstance(value, np.ndarray):
+ # yield tuple(value.reshape(-1))
+
+ # elif isinstance(value, tuple):
+ # yield value
+
+ # elif isinstance(value, dict):
+ # raise Exception('is this right?')
+
+ # for variable in variables:
+ # yield from value[variable]
+
+ # elif isinstance(value, int) or isinstance(value, float):
+ # yield (value,)
+
+ # else:
+ # yield (float(value),)
+
+ # values = tuple(gen_formatted_values())
+
+ # if len(values) == 1:
+ # values = tuple((values[0],) for _ in variables)
+
+ # else:
+ # assert len(variables) == len(values), f'length of {variables} does not match length of {values}'
+
+ # def gen_flattened_values():
+ # for value in values:
+ # if isinstance(value, np.ndarray):
+ # yield from value.reshape(-1)
+
+ # elif isinstance(value, tuple):
+ # yield from value
+
+ # elif isinstance(value, dict):
+ # raise Exception('is this right?')
+
+ # for variable in variables:
+ # yield from value[variable]
+
+ # elif isinstance(value, int) or isinstance(value, float):
+ # yield value
+
+ # else:
+ # yield float(value)
+
+ # values = tuple(gen_flattened_values())
+
+# return EvalExprImpl(
+# underlying=underlying,
+# variables=variables,
+# values=values,
+# )
diff --git a/polymatrix/expression/init/initfromsympyexpr.py b/polymatrix/expression/init/initfromsympyexpr.py
index bb37f1d..3fb52f7 100644
--- a/polymatrix/expression/init/initfromsympyexpr.py
+++ b/polymatrix/expression/init/initfromsympyexpr.py
@@ -36,7 +36,13 @@ def init_from_sympy_expr(
case _:
data = tuple((e,) for e in data)
+ case np.number:
+ data = ((float(data),),)
+
case _:
+ if not isinstance(data, (float, int, sympy.Expr)):
+ raise Exception(f'{data=}, {type(data)=}')
+
data = ((data,),)
return FromSympyExprImpl(
diff --git a/polymatrix/expression/init/initlinearinexpr.py b/polymatrix/expression/init/initlinearinexpr.py
index 5cd172c..b869aee 100644
--- a/polymatrix/expression/init/initlinearinexpr.py
+++ b/polymatrix/expression/init/initlinearinexpr.py
@@ -8,6 +8,8 @@ def init_linear_in_expr(
variables: ExpressionBaseMixin,
ignore_unmatched: bool = None,
):
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+
return LinearInExprImpl(
underlying=underlying,
monomials=monomials,
diff --git a/polymatrix/expression/init/initlinearmonomialsexpr.py b/polymatrix/expression/init/initlinearmonomialsexpr.py
index f116562..8083715 100644
--- a/polymatrix/expression/init/initlinearmonomialsexpr.py
+++ b/polymatrix/expression/init/initlinearmonomialsexpr.py
@@ -4,9 +4,12 @@ from polymatrix.expression.impl.linearmonomialsexprimpl import LinearMonomialsEx
def init_linear_monomials_expr(
underlying: ExpressionBaseMixin,
- variables: tuple,
+ variables: ExpressionBaseMixin,
):
+
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+
return LinearMonomialsExprImpl(
underlying=underlying,
variables=variables,
-)
+ )
diff --git a/polymatrix/expression/init/initquadraticinexpr.py b/polymatrix/expression/init/initquadraticinexpr.py
index 5aa40a5..6555b4b 100644
--- a/polymatrix/expression/init/initquadraticinexpr.py
+++ b/polymatrix/expression/init/initquadraticinexpr.py
@@ -5,10 +5,13 @@ from polymatrix.expression.impl.quadraticinexprimpl import QuadraticInExprImpl
def init_quadratic_in_expr(
underlying: ExpressionBaseMixin,
monomials: ExpressionBaseMixin,
- variables: tuple,
+ variables: ExpressionBaseMixin,
):
+
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+
return QuadraticInExprImpl(
underlying=underlying,
monomials=monomials,
variables=variables,
-)
+ )
diff --git a/polymatrix/expression/init/initquadraticmonomialsexpr.py b/polymatrix/expression/init/initquadraticmonomialsexpr.py
index 8e46c62..190f7df 100644
--- a/polymatrix/expression/init/initquadraticmonomialsexpr.py
+++ b/polymatrix/expression/init/initquadraticmonomialsexpr.py
@@ -4,9 +4,12 @@ from polymatrix.expression.impl.quadraticmonomialsexprimpl import QuadraticMonom
def init_quadratic_monomials_expr(
underlying: ExpressionBaseMixin,
- variables: tuple,
+ variables: ExpressionBaseMixin,
):
+
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+
return QuadraticMonomialsExprImpl(
underlying=underlying,
variables=variables,
-)
+ )
diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py
index 403b169..50cbee0 100644
--- a/polymatrix/expression/init/initsubstituteexpr.py
+++ b/polymatrix/expression/init/initsubstituteexpr.py
@@ -1,3 +1,4 @@
+import typing
import numpy as np
from polymatrix.expression.init.initfromsympyexpr import init_from_sympy_expr
@@ -5,35 +6,97 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.impl.substituteexprimpl import SubstituteExprImpl
+def format_substitutions(
+ variables: typing.Union[typing.Any, tuple, dict],
+ values: typing.Union[float, tuple] = None,
+):
+ """
+ (variables = x, values = 1.0) # ok
+ (variables = x, values = np.array(1.0)) # ok
+ (variables = (x, y, z), values = 1.0) # ok
+ (variables = (x, y, z), values = (1.0, 2.0, 3.0)) # ok
+ (variables = {x: 1.0, y: 2.0, z: 3.0}) # ok
+ (variables = ((x, 1.0), (y, 2.0), (z, 3.0))) # ok
+
+ (variables = v, values = (1.0, 2.0)) # ok
+ (variables = (v1, v2), values = ((1.0, 2.0), (3.0,))) # ok
+ (variables = (v1, v2), values = (1.0, 2.0, 3.0)) # not ok
+ """
+
+ if values is not None:
+ if isinstance(variables, tuple):
+ if isinstance(values, tuple):
+ assert len(variables) == len(values), f'{variables=}, {values=}'
+
+ else:
+ values = tuple(values for _ in variables)
+
+ else:
+ variables = (variables,)
+ values = (values,)
+
+ substitutions = zip(variables, values)
+
+ elif isinstance(variables, dict):
+ substitutions = variables.items()
+
+ elif isinstance(variables, tuple):
+ substitutions = variables
+
+ else:
+ raise Exception(f'{variables=}')
+
+ return substitutions
+
+
def init_substitute_expr(
underlying: ExpressionBaseMixin,
variables: tuple,
- substitutions: tuple = None,
+ values: tuple = None,
):
- if substitutions is None:
- assert isinstance(variables, tuple)
-
- if len(variables) == 0:
- return underlying
-
- variables, substitutions = tuple(zip(*variables))
- elif isinstance(substitutions, np.ndarray):
- substitutions = tuple(substitutions.reshape(-1))
+ substitutions = format_substitutions(
+ variables=variables,
+ values=values,
+ )
- elif not isinstance(substitutions, tuple):
- substitutions = (substitutions,)
+ def formatted_values(value) -> ExpressionBaseMixin:
+ if isinstance(value, ExpressionBaseMixin):
+ return value
+ else:
+ return init_from_sympy_expr(value)
- def gen_substitutions():
- for substitution in substitutions:
- match substitution:
- case ExpressionBaseMixin():
- yield substitution
- case _:
- yield init_from_sympy_expr(substitution)
+ substitutions = tuple((variable, formatted_values(value)) for variable, value in substitutions)
return SubstituteExprImpl(
underlying=underlying,
- variables=variables,
- substitutions=tuple(gen_substitutions()),
-)
+ substitutions=substitutions,
+ )
+
+ # if values is None:
+ # assert isinstance(variables, tuple)
+
+ # if len(variables) == 0:
+ # return underlying
+
+ # variables, values = tuple(zip(*variables))
+
+ # elif isinstance(values, np.ndarray):
+ # values = tuple(values.reshape(-1))
+
+ # elif not isinstance(values, tuple):
+ # values = (values,)
+
+ # def gen_substitutions():
+ # for substitution in values:
+ # match substitution:
+ # case ExpressionBaseMixin():
+ # yield substitution
+ # case _:
+ # yield init_from_sympy_expr(substitution)
+
+ # return SubstituteExprImpl(
+ # underlying=underlying,
+ # variables=variables,
+ # substitutions=tuple(gen_substitutions()),
+ # )