diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-01-30 16:19:24 +0100 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2023-01-30 16:19:24 +0100 |
commit | 09175e1f03dc260f743de28d47a16d2f5a97bf38 (patch) | |
tree | 09320967339e60467c3f6f2a2fb6d999f112e8b5 | |
parent | update README (diff) | |
download | polymatrix-09175e1f03dc260f743de28d47a16d2f5a97bf38.tar.gz polymatrix-09175e1f03dc260f743de28d47a16d2f5a97bf38.zip |
bugfix in eval/substitution operator
27 files changed, 559 insertions, 222 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 0c15ae3..e6cf4a8 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -17,7 +17,7 @@ from polymatrix.expression.init.initfromsympyexpr import init_from_sympy_expr from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr from polymatrix.expression.init.initvstackexpr import init_v_stack_expr from polymatrix.polymatrix.polymatrix import PolyMatrix -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices, get_variable_indices_from_variable from polymatrix.statemonad.init.initstatemonad import init_state_monad from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin from polymatrix.expression.utils.monomialtoindex import monomial_to_index @@ -363,7 +363,7 @@ class MatrixRepresentations: return dict(gen_matrices()) def get_value(self, variable, value): - variable_indices = get_variable_indices(self.state, variable)[1] + variable_indices = get_variable_indices_from_variable(self.state, variable)[1] def gen_value_index(): for variable_index in variable_indices: @@ -377,7 +377,7 @@ class MatrixRepresentations: return value[value_index] def set_value(self, variable, value): - variable_indices = get_variable_indices(self.state, variable)[1] + variable_indices = get_variable_indices_from_variable(self.state, variable)[1] value_index = list(self.variable_mapping.index(variable_index) for variable_index in variable_indices) vec = np.zeros(len(self.variable_mapping)) vec[value_index] = value @@ -430,13 +430,15 @@ class MatrixRepresentations: return func def to_matrix_repr( - expressions: tuple[Expression], + expressions: Expression | tuple[Expression], variables: Expression, ) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]: if isinstance(expressions, Expression): expressions = (expressions,) + assert isinstance(variables, Expression), f'{variables=}' + def func(state: ExpressionState): def acc_underlying_application(acc, v): @@ -454,39 +456,7 @@ def to_matrix_repr( initial=(state, tuple()), )) - # if variables is None: - - # def gen_used_variables(): - # def gen_used_auxillary_variables(considered): - # monomial_terms = state.auxillary_equations[considered[-1]] - # for monomial in monomial_terms.keys(): - # for variable in monomial: - # yield variable - - # if variable not in considered and variable in state.auxillary_equations: - # yield from gen_used_auxillary_variables(considered + (variable,)) - - # for underlying in underlying_list: - # for row in range(underlying.shape[0]): - # for col in range(underlying.shape[1]): - - # try: - # underlying_terms = underlying.get_poly(row, col) - # except KeyError: - # continue - - # for monomial in underlying_terms.keys(): - # for variable, _ in monomial: - # yield variable - - # if variable in state.auxillary_equations: - # yield from gen_used_auxillary_variables((variable,)) - - # ordered_variable_index = tuple(sorted(set(gen_used_variables()))) - - # else: - - state, ordered_variable_index = get_variable_indices(state, variables) + state, ordered_variable_index = get_variable_indices_from_variable(state, variables) assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables' @@ -534,8 +504,6 @@ def to_matrix_repr( underlying_matrices = tuple(gen_underlying_matrices()) - # current_row = underlying.shape[0] - def gen_auxillary_equations(): for key, monomial_terms in state.auxillary_equations.items(): if key in ordered_variable_index: @@ -581,6 +549,7 @@ def to_matrix_repr( def to_constant_repr( expr: Expression, + assert_constant: bool = True, ) -> StateMonadMixin[ExpressionState, np.ndarray]: def func(state: ExpressionState): @@ -592,6 +561,9 @@ def to_constant_repr( for monomial, value in polynomial.items(): if len(monomial) == 0: A[row, col] = value + + elif assert_constant: + raise Exception(f'non-constant term {monomial=}') return state, A @@ -605,7 +577,7 @@ def degrees( def func(state: ExpressionState): state, underlying = expr.apply(state) - state, variable_indices = get_variable_indices(state, variables) + state, variable_indices = get_variable_indices_from_variable(state, variables) def gen_rows(): for row in range(underlying.shape[0]): diff --git a/polymatrix/expression/impl/__init__.py b/polymatrix/expression/impl/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/polymatrix/expression/impl/__init__.py diff --git a/polymatrix/expression/impl/evalexprimpl.py b/polymatrix/expression/impl/evalexprimpl.py index cd97155..7f13a78 100644 --- a/polymatrix/expression/impl/evalexprimpl.py +++ b/polymatrix/expression/impl/evalexprimpl.py @@ -6,5 +6,5 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @dataclass_abc.dataclass_abc(frozen=True) class EvalExprImpl(EvalExpr): underlying: ExpressionBaseMixin - variables: tuple - values: tuple + substitutions: tuple + # values: tuple diff --git a/polymatrix/expression/impl/substituteexprimpl.py b/polymatrix/expression/impl/substituteexprimpl.py index 1eae940..086f17e 100644 --- a/polymatrix/expression/impl/substituteexprimpl.py +++ b/polymatrix/expression/impl/substituteexprimpl.py @@ -6,5 +6,5 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @dataclass_abc.dataclass_abc(frozen=True) class SubstituteExprImpl(SubstituteExpr): underlying: ExpressionBaseMixin - variables: tuple + # variables: tuple substitutions: tuple diff --git a/polymatrix/expression/init/initderivativeexpr.py b/polymatrix/expression/init/initderivativeexpr.py index c640f47..a6ca06c 100644 --- a/polymatrix/expression/init/initderivativeexpr.py +++ b/polymatrix/expression/init/initderivativeexpr.py @@ -4,9 +4,12 @@ from polymatrix.expression.impl.derivativeexprimpl import DerivativeExprImpl def init_derivative_expr( underlying: ExpressionBaseMixin, - variables: tuple, + variables: ExpressionBaseMixin, introduce_derivatives: bool = None, ): + + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + if introduce_derivatives is None: introduce_derivatives = False diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py index 5bbd404..0fad8e5 100644 --- a/polymatrix/expression/init/initevalexpr.py +++ b/polymatrix/expression/init/initevalexpr.py @@ -1,55 +1,171 @@ +import typing import numpy as np +from polymatrix.expression.init.initsubstituteexpr import format_substitutions from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.impl.evalexprimpl import EvalExprImpl def init_eval_expr( underlying: ExpressionBaseMixin, - variables: tuple, - values: tuple = None, + variables: typing.Union[typing.Any, tuple, dict], + values: typing.Union[float, tuple] = None, ): - if values is None: - if isinstance(variables, tuple): - variables, values = tuple(zip(*variables)) + substitutions = format_substitutions( + variables=variables, + values=values, + ) + + def formatted_values(value): + if isinstance(value, np.ndarray): + return tuple(value.reshape(-1)) + + elif isinstance(value, tuple): + return value - elif isinstance(variables, dict): - variables, values = tuple(zip(*variables.items())) + elif isinstance(value, int) or isinstance(value, float): + return (value,) else: - raise Exception(f'unsupported case {variables=}') + return (float(value),) + + substitutions = tuple((variable, formatted_values(value)) for variable, value in substitutions) + + return EvalExprImpl( + underlying=underlying, + substitutions=substitutions, + ) - elif isinstance(values, np.ndarray): - values = tuple(values.reshape(-1)) + # if values is not None: + # if isinstance(variables, tuple): + # if isinstance(values, tuple): + # assert len(variables) == len(values), f'{variables=}, {values=}' - elif not isinstance(values, tuple): - values = (values,) + # else: + # values = tuple(values for _ in variables) - if not isinstance(variables, tuple): - variables = (variables,) + # else: + # variables = (variables,) + # values = (values,) - def gen_formatted_values(): - for value in values: - if isinstance(value, np.ndarray): - yield from value.reshape(-1) + # subs = zip(variables, values) + + # elif isinstance(variables, dict): + # subs = variables.items() - elif isinstance(value, tuple): - yield from value + # elif isinstance(variables, tuple): + # subs = variables - elif isinstance(value, dict): - for variable in variables: - yield from value[variable] + # else: + # raise Exception(f'{variables=}') - elif isinstance(value, int) or isinstance(value, float): - yield value + # def formatted_values(value): + # if isinstance(value, np.ndarray): + # return tuple(value.reshape(-1)) - else: - yield float(value) + # elif isinstance(value, tuple): + # return value - values = tuple(gen_formatted_values()) + # elif isinstance(value, int) or isinstance(value, float): + # return (value,) - return EvalExprImpl( - underlying=underlying, - variables=variables, - values=values, -) + # else: + # return (float(value),) + + # subs = tuple((var, formatted_values(val)) for var, val in subs) + + # def formatted_values(value): + # # def gen_formatted_values(): + # # for value in values: + # if isinstance(value, np.ndarray): + # yield tuple(value.reshape(-1)) + + # elif isinstance(value, tuple): + # yield value + + # # elif isinstance(value, dict): + # # for variable in variables: + # # yield from value[variable] + + # elif isinstance(value, int) or isinstance(value, float): + # yield (value,) + + # else: + # yield (float(value),) + # return tuple(gen_formatted_values()) + + + # if values is None: + # if isinstance(variables, tuple): + # variables, values = tuple(zip(*variables)) + + # elif isinstance(variables, dict): + # variables, values = tuple(zip(*variables.items())) + + # else: + # raise Exception(f'unsupported case {variables=}') + + # elif isinstance(values, np.ndarray): + # values = tuple(values.reshape(-1)) + + # elif not isinstance(values, tuple): + # values = (values,) + + # if not isinstance(variables, tuple): + # variables = (variables,) + + # def gen_formatted_values(): + # for value in values: + # if isinstance(value, np.ndarray): + # yield tuple(value.reshape(-1)) + + # elif isinstance(value, tuple): + # yield value + + # elif isinstance(value, dict): + # raise Exception('is this right?') + + # for variable in variables: + # yield from value[variable] + + # elif isinstance(value, int) or isinstance(value, float): + # yield (value,) + + # else: + # yield (float(value),) + + # values = tuple(gen_formatted_values()) + + # if len(values) == 1: + # values = tuple((values[0],) for _ in variables) + + # else: + # assert len(variables) == len(values), f'length of {variables} does not match length of {values}' + + # def gen_flattened_values(): + # for value in values: + # if isinstance(value, np.ndarray): + # yield from value.reshape(-1) + + # elif isinstance(value, tuple): + # yield from value + + # elif isinstance(value, dict): + # raise Exception('is this right?') + + # for variable in variables: + # yield from value[variable] + + # elif isinstance(value, int) or isinstance(value, float): + # yield value + + # else: + # yield float(value) + + # values = tuple(gen_flattened_values()) + +# return EvalExprImpl( +# underlying=underlying, +# variables=variables, +# values=values, +# ) diff --git a/polymatrix/expression/init/initfromsympyexpr.py b/polymatrix/expression/init/initfromsympyexpr.py index bb37f1d..3fb52f7 100644 --- a/polymatrix/expression/init/initfromsympyexpr.py +++ b/polymatrix/expression/init/initfromsympyexpr.py @@ -36,7 +36,13 @@ def init_from_sympy_expr( case _: data = tuple((e,) for e in data) + case np.number: + data = ((float(data),),) + case _: + if not isinstance(data, (float, int, sympy.Expr)): + raise Exception(f'{data=}, {type(data)=}') + data = ((data,),) return FromSympyExprImpl( diff --git a/polymatrix/expression/init/initlinearinexpr.py b/polymatrix/expression/init/initlinearinexpr.py index 5cd172c..b869aee 100644 --- a/polymatrix/expression/init/initlinearinexpr.py +++ b/polymatrix/expression/init/initlinearinexpr.py @@ -8,6 +8,8 @@ def init_linear_in_expr( variables: ExpressionBaseMixin, ignore_unmatched: bool = None, ): + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + return LinearInExprImpl( underlying=underlying, monomials=monomials, diff --git a/polymatrix/expression/init/initlinearmonomialsexpr.py b/polymatrix/expression/init/initlinearmonomialsexpr.py index f116562..8083715 100644 --- a/polymatrix/expression/init/initlinearmonomialsexpr.py +++ b/polymatrix/expression/init/initlinearmonomialsexpr.py @@ -4,9 +4,12 @@ from polymatrix.expression.impl.linearmonomialsexprimpl import LinearMonomialsEx def init_linear_monomials_expr( underlying: ExpressionBaseMixin, - variables: tuple, + variables: ExpressionBaseMixin, ): + + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + return LinearMonomialsExprImpl( underlying=underlying, variables=variables, -) + ) diff --git a/polymatrix/expression/init/initquadraticinexpr.py b/polymatrix/expression/init/initquadraticinexpr.py index 5aa40a5..6555b4b 100644 --- a/polymatrix/expression/init/initquadraticinexpr.py +++ b/polymatrix/expression/init/initquadraticinexpr.py @@ -5,10 +5,13 @@ from polymatrix.expression.impl.quadraticinexprimpl import QuadraticInExprImpl def init_quadratic_in_expr( underlying: ExpressionBaseMixin, monomials: ExpressionBaseMixin, - variables: tuple, + variables: ExpressionBaseMixin, ): + + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + return QuadraticInExprImpl( underlying=underlying, monomials=monomials, variables=variables, -) + ) diff --git a/polymatrix/expression/init/initquadraticmonomialsexpr.py b/polymatrix/expression/init/initquadraticmonomialsexpr.py index 8e46c62..190f7df 100644 --- a/polymatrix/expression/init/initquadraticmonomialsexpr.py +++ b/polymatrix/expression/init/initquadraticmonomialsexpr.py @@ -4,9 +4,12 @@ from polymatrix.expression.impl.quadraticmonomialsexprimpl import QuadraticMonom def init_quadratic_monomials_expr( underlying: ExpressionBaseMixin, - variables: tuple, + variables: ExpressionBaseMixin, ): + + assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + return QuadraticMonomialsExprImpl( underlying=underlying, variables=variables, -) + ) diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py index 403b169..50cbee0 100644 --- a/polymatrix/expression/init/initsubstituteexpr.py +++ b/polymatrix/expression/init/initsubstituteexpr.py @@ -1,3 +1,4 @@ +import typing import numpy as np from polymatrix.expression.init.initfromsympyexpr import init_from_sympy_expr @@ -5,35 +6,97 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.impl.substituteexprimpl import SubstituteExprImpl +def format_substitutions( + variables: typing.Union[typing.Any, tuple, dict], + values: typing.Union[float, tuple] = None, +): + """ + (variables = x, values = 1.0) # ok + (variables = x, values = np.array(1.0)) # ok + (variables = (x, y, z), values = 1.0) # ok + (variables = (x, y, z), values = (1.0, 2.0, 3.0)) # ok + (variables = {x: 1.0, y: 2.0, z: 3.0}) # ok + (variables = ((x, 1.0), (y, 2.0), (z, 3.0))) # ok + + (variables = v, values = (1.0, 2.0)) # ok + (variables = (v1, v2), values = ((1.0, 2.0), (3.0,))) # ok + (variables = (v1, v2), values = (1.0, 2.0, 3.0)) # not ok + """ + + if values is not None: + if isinstance(variables, tuple): + if isinstance(values, tuple): + assert len(variables) == len(values), f'{variables=}, {values=}' + + else: + values = tuple(values for _ in variables) + + else: + variables = (variables,) + values = (values,) + + substitutions = zip(variables, values) + + elif isinstance(variables, dict): + substitutions = variables.items() + + elif isinstance(variables, tuple): + substitutions = variables + + else: + raise Exception(f'{variables=}') + + return substitutions + + def init_substitute_expr( underlying: ExpressionBaseMixin, variables: tuple, - substitutions: tuple = None, + values: tuple = None, ): - if substitutions is None: - assert isinstance(variables, tuple) - - if len(variables) == 0: - return underlying - - variables, substitutions = tuple(zip(*variables)) - elif isinstance(substitutions, np.ndarray): - substitutions = tuple(substitutions.reshape(-1)) + substitutions = format_substitutions( + variables=variables, + values=values, + ) - elif not isinstance(substitutions, tuple): - substitutions = (substitutions,) + def formatted_values(value) -> ExpressionBaseMixin: + if isinstance(value, ExpressionBaseMixin): + return value + else: + return init_from_sympy_expr(value) - def gen_substitutions(): - for substitution in substitutions: - match substitution: - case ExpressionBaseMixin(): - yield substitution - case _: - yield init_from_sympy_expr(substitution) + substitutions = tuple((variable, formatted_values(value)) for variable, value in substitutions) return SubstituteExprImpl( underlying=underlying, - variables=variables, - substitutions=tuple(gen_substitutions()), -) + substitutions=substitutions, + ) + + # if values is None: + # assert isinstance(variables, tuple) + + # if len(variables) == 0: + # return underlying + + # variables, values = tuple(zip(*variables)) + + # elif isinstance(values, np.ndarray): + # values = tuple(values.reshape(-1)) + + # elif not isinstance(values, tuple): + # values = (values,) + + # def gen_substitutions(): + # for substitution in values: + # match substitution: + # case ExpressionBaseMixin(): + # yield substitution + # case _: + # yield init_from_sympy_expr(substitution) + + # return SubstituteExprImpl( + # underlying=underlying, + # variables=variables, + # substitutions=tuple(gen_substitutions()), + # ) diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index 5fce215..183608b 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -7,7 +7,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.polymatrix import PolyMatrix from polymatrix.expressionstate.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable class DerivativeExprMixin(ExpressionBaseMixin): @@ -18,7 +18,7 @@ class DerivativeExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def variables(self) -> typing.Union[tuple, ExpressionBaseMixin]: + def variables(self) -> ExpressionBaseMixin: ... @property @@ -33,7 +33,7 @@ class DerivativeExprMixin(ExpressionBaseMixin): ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - state, diff_wrt_variables = get_variable_indices(state, self.variables) + state, diff_wrt_variables = get_variable_indices_from_variable(state, self.variables) assert underlying.shape[1] == 1, f'{underlying.shape=}' diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py index 3002109..8ada607 100644 --- a/polymatrix/expression/mixins/divergenceexprmixin.py +++ b/polymatrix/expression/mixins/divergenceexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.polymatrix import PolyMatrix from polymatrix.expressionstate.expressionstate import ExpressionState from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable class DivergenceExprMixin(ExpressionBaseMixin): @@ -29,7 +29,7 @@ class DivergenceExprMixin(ExpressionBaseMixin): ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - state, variables = get_variable_indices(state, self.variables) + state, variables = get_variable_indices_from_variable(state, self.variables) assert underlying.shape[1] == 1, f'{underlying.shape=}' assert len(variables) == underlying.shape[0], f'{variables=}, {underlying.shape=}' diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py index 2358566..cbf4ce8 100644 --- a/polymatrix/expression/mixins/evalexprmixin.py +++ b/polymatrix/expression/mixins/evalexprmixin.py @@ -7,7 +7,7 @@ from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.polymatrix import PolyMatrix from polymatrix.expressionstate.expressionstate import ExpressionState -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable class EvalExprMixin(ExpressionBaseMixin): @@ -18,12 +18,7 @@ class EvalExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def variables(self) -> tuple: - ... - - @property - @abc.abstractmethod - def values(self) -> tuple[float, ...]: + def substitutions(self) -> tuple: ... # overwrites abstract method of `ExpressionBaseMixin` @@ -31,17 +26,32 @@ class EvalExprMixin(ExpressionBaseMixin): self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: + state, underlying = self.underlying.apply(state=state) - state, variable_indices = get_variable_indices(state, self.variables) - if len(self.values) == 1: - values = tuple(self.values[0] for _ in variable_indices) + def acc_variable_indices_and_values(acc, next): + state, acc_indices, acc_values = acc + variable, values = next - else: - assert len(variable_indices) == len(self.values), f'length of {variable_indices} does not match length of {self.values}' + state, indices = get_variable_indices_from_variable(state, variable) - values = self.values + if indices is None: + return acc + if len(values) == 1: + values = tuple(values[0] for _ in indices) + + else: + assert len(indices) == len(values), f'{variable=}, {indices=} ({len(indices)}), {values=} ({len(values)})' + + return state, acc_indices + indices, acc_values + values + + *_, (state, variable_indices, values) = itertools.accumulate( + self.substitutions, + acc_variable_indices_and_values, + initial=(state, tuple(), tuple()) + ) + terms = {} for row in range(underlying.shape[0]): @@ -91,3 +101,27 @@ class EvalExprMixin(ExpressionBaseMixin): ) return state, poly_matrix + + + # if len(self.values) == 1: + # values = tuple(self.values[0] for _ in self.variables) + + # else: + # values = self.values + + # def filter_valid_variables(): + # for var, val in zip(self.variables, self.values): + # if isinstance(var, ExpressionBaseMixin) or isinstance(var, int) or (var in state.offset_dict): + # yield var, val + + # variables, values = zip(*filter_valid_variables()) + + # state, variable_indices = get_variable_indices(state, self.variables) + + # if len(self.values) == 1: + # values = tuple(self.values[0] for _ in variable_indices) + + # else: + # assert len(variable_indices) == len(self.values), f'length of {variable_indices} does not match length of {self.values}' + + # values = self.values diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py index 5602294..b828d6a 100644 --- a/polymatrix/expression/mixins/expressionmixin.py +++ b/polymatrix/expression/mixins/expressionmixin.py @@ -2,6 +2,7 @@ import abc import dataclasses import typing import numpy as np +import sympy from polymatrix.expression.init.initadditionexpr import init_addition_expr from polymatrix.expression.init.initcacheexpr import init_cache_expr @@ -27,6 +28,7 @@ from polymatrix.expression.init.initrepmatexpr import init_rep_mat_expr from polymatrix.expression.init.initreshapeexpr import init_reshape_expr from polymatrix.expression.init.initsetelementatexpr import init_set_element_at_expr from polymatrix.expression.init.initquadraticmonomialsexpr import init_quadratic_monomials_expr +from polymatrix.expression.init.initsqueezeexpr import init_squeeze_expr from polymatrix.expression.init.initsubstituteexpr import init_substitute_expr from polymatrix.expression.init.initsubtractmonomialsexpr import init_subtract_monomials_expr from polymatrix.expression.init.initsumexpr import init_sum_expr @@ -60,11 +62,7 @@ class ExpressionMixin( if other is None: return self - match other: - case ExpressionBaseMixin(): - right = other - case _: - right = init_from_sympy_expr(other) + right = self._convert_to_expression(other) return dataclasses.replace( self, @@ -96,11 +94,7 @@ class ExpressionMixin( ) def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin': - match other: - case ExpressionBaseMixin(): - right = other - case _: - right = init_from_sympy_expr(other) + right = self._convert_to_expression(other) return dataclasses.replace( self, @@ -111,13 +105,7 @@ class ExpressionMixin( ) def __mul__(self, other) -> 'ExpressionMixin': - # assert isinstance(other, float) - - match other: - case ExpressionBaseMixin(): - right = other - case _: - right = init_from_sympy_expr(other) + right = self._convert_to_expression(other) return dataclasses.replace( self, @@ -127,6 +115,14 @@ class ExpressionMixin( ), ) + def __pow__(self, num): + curr = 1 + + for _ in range(num): + curr = curr * self + + return curr + def __neg__(self): return self * (-1) @@ -134,9 +130,9 @@ class ExpressionMixin( return self + other def __rmatmul__(self, other): - other = init_from_sympy_expr(other) + left = self._convert_to_expression(other) - return other @ self + return left @ self def __rmul__(self, other): return self * other @@ -145,11 +141,7 @@ class ExpressionMixin( return self + other * (-1) def __truediv__(self, other: ExpressionBaseMixin): - match other: - case ExpressionBaseMixin(): - right = other - case _: - right = init_from_sympy_expr(other) + right = self._convert_to_expression(other) return dataclasses.replace( self, @@ -159,6 +151,18 @@ class ExpressionMixin( ), ) + def _convert_to_expression(self, other): + if isinstance(other, ExpressionBaseMixin): + return other + + # can_convert = isinstance(other, (float, int, sympy.Expr, np.ndarray)) + + # if not can_convert: + # raise Exception(f'{other} cannot be converted to an Expression') + # else: + + return init_from_sympy_expr(other) + def cache(self) -> 'ExpressionMixin': return dataclasses.replace( self, @@ -371,6 +375,16 @@ class ExpressionMixin( ), ) + def squeeze( + self, + ) -> 'ExpressionMixin': + return dataclasses.replace( + self, + underlying=init_squeeze_expr( + underlying=self.underlying, + ), + ) + def subtract_monomials( self, monomials: 'ExpressionMixin', @@ -386,25 +400,25 @@ class ExpressionMixin( def substitute( self, variable: tuple, - substitutions: tuple['ExpressionMixin', ...] = None, + values: tuple['ExpressionMixin', ...] = None, ) -> 'ExpressionMixin': return dataclasses.replace( self, underlying=init_substitute_expr( underlying=self.underlying, variables=variable, - substitutions=substitutions, + values=values, ), ) def subs( self, variable: tuple, - substitutions: tuple['ExpressionMixin', ...] = None, + values: tuple['ExpressionMixin', ...] = None, ) -> 'ExpressionMixin': return self.substitute( variable=variable, - substitutions=substitutions, + values=values, ) def sum(self): diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py index 5083073..ec549a8 100644 --- a/polymatrix/expression/mixins/filterexprmixin.py +++ b/polymatrix/expression/mixins/filterexprmixin.py @@ -1,9 +1,5 @@ import abc -import collections -import math -import typing -import dataclass_abc from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -12,6 +8,7 @@ from polymatrix.expressionstate.expressionstate import ExpressionState class FilterExprMixin(ExpressionBaseMixin): + @property @abc.abstractmethod def underlying(self) -> ExpressionBaseMixin: @@ -45,6 +42,7 @@ class FilterExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): underlying_terms = underlying.get_poly(row, 0) + if underlying_terms is None: continue diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index cd94761..9bcabbe 100644 --- a/polymatrix/expression/mixins/linearinexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -7,10 +7,30 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.polymatrix import PolyMatrix from polymatrix.expressionstate.expressionstate import ExpressionState from polymatrix.expression.utils.getmonomialindices import get_monomial_indices -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable class LinearInExprMixin(ExpressionBaseMixin): + """ + Maps a polynomial column vector + + underlying = [ + [1 + a x], + [x^2 ], + ] + + into a polynomial matrix + + output = [ + [1, a, 0], + [0, 0, 1], + ], + + where each column corresponds to a monomial defined by + + monomials = [1, x, x^2]. + """ + @property @abc.abstractmethod def underlying(self) -> ExpressionBaseMixin: @@ -23,7 +43,7 @@ class LinearInExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def variables(self) -> tuple: + def variables(self) -> ExpressionBaseMixin: ... @property @@ -39,7 +59,7 @@ class LinearInExprMixin(ExpressionBaseMixin): state, underlying = self.underlying.apply(state=state) state, monomials = get_monomial_indices(state, self.monomials) - state, variable_indices = get_variable_indices(state, self.variables) + state, variable_indices = get_variable_indices_from_variable(state, self.variables) assert underlying.shape[1] == 1 diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py index d013722..f5bd18f 100644 --- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py +++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py @@ -7,7 +7,7 @@ from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.polymatrix import PolyMatrix from polymatrix.expressionstate.expressionstate import ExpressionState -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable class LinearMatrixInExprMixin(ExpressionBaseMixin): @@ -27,7 +27,7 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin): state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - state, variable_index = get_variable_indices(state, variables=self.variable) + state, variable_index = get_variable_indices_from_variable(state, variables=self.variable) assert len(variable_index) == 1 diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py index b309ccf..350ac6d 100644 --- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py @@ -6,11 +6,28 @@ from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionStateMixin from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable from polymatrix.expression.utils.sortmonomials import sort_monomials class LinearMonomialsExprMixin(ExpressionBaseMixin): + """ + Maps a polynomial matrix + + underlying = [ + [1, a x ], + [x^2, x + x^2], + ] + + into a vector of monomials + + output = [1, x, x^2] + + in variable + + variables = [x]. + """ + @property @abc.abstractclassmethod def underlying(self) -> ExpressionBaseMixin: @@ -18,7 +35,7 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def variables(self) -> tuple: + def variables(self) -> ExpressionBaseMixin: ... # overwrites abstract method of `ExpressionBaseMixin` @@ -28,7 +45,7 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin): ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: state, underlying = self.underlying.apply(state=state) - state, variable_indices = get_variable_indices(state, self.variables) + state, variable_indices = get_variable_indices_from_variable(state, self.variables) def gen_linear_monomials(): for row in range(underlying.shape[0]): diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py index 130ab79..5c32f45 100644 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ b/polymatrix/expression/mixins/quadraticinexprmixin.py @@ -7,7 +7,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.polymatrix import PolyMatrix from polymatrix.expressionstate.expressionstate import ExpressionState from polymatrix.expression.utils.getmonomialindices import get_monomial_indices -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices @@ -35,7 +35,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin): state, underlying = self.underlying.apply(state=state) state, sos_monomials = get_monomial_indices(state, self.monomials) - state, variable_indices = get_variable_indices(state, self.variables) + state, variable_indices = get_variable_indices_from_variable(state, self.variables) assert underlying.shape == (1, 1), f'underlying shape is {underlying.shape}' diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py index bddc321..81fbcb9 100644 --- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py @@ -6,11 +6,28 @@ from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionStateMixin from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices class QuadraticMonomialsExprMixin(ExpressionBaseMixin): + """ + Maps a polynomial matrix + + underlying = [ + [x y ], + [x + x^2], + ] + + into a vector of monomials + + output = [1, x, y] + + in variable + + variables = [x, y]. + """ + @property @abc.abstractclassmethod def underlying(self) -> ExpressionBaseMixin: @@ -18,7 +35,7 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def variables(self) -> tuple: + def variables(self) -> ExpressionBaseMixin: ... # overwrites abstract method of `ExpressionBaseMixin` @@ -28,7 +45,7 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin): ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: state, underlying = self.underlying.apply(state=state) - state, variable_indices = get_variable_indices(state, self.variables) + state, variable_indices = get_variable_indices_from_variable(state, self.variables) def gen_sos_monomials(): for row in range(underlying.shape[0]): diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py index 9741897..6c0beae 100644 --- a/polymatrix/expression/mixins/substituteexprmixin.py +++ b/polymatrix/expression/mixins/substituteexprmixin.py @@ -3,12 +3,13 @@ import abc import collections import itertools import math +import typing from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.polymatrix import PolyMatrix from polymatrix.expressionstate.expressionstate import ExpressionState -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial @@ -20,12 +21,7 @@ class SubstituteExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def variables(self) -> tuple: - ... - - @property - @abc.abstractmethod - def substitutions(self) -> tuple[ExpressionBaseMixin, ...]: + def substitutions(self) -> tuple[tuple[typing.Any, ExpressionBaseMixin], ...]: ... # overwrites abstract method of `ExpressionBaseMixin` @@ -34,33 +30,40 @@ class SubstituteExprMixin(ExpressionBaseMixin): state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - state, variable_indices = get_variable_indices(state, self.variables) - def acc_substitutions(acc, substitution_expr): - state, result = acc + def acc_substitutions(acc, next): + state, acc_variable, acc_substitution = acc + variable, expr = next - # for expr in self.expressions: - if isinstance(substitution_expr, ExpressionBaseMixin): - state, substitution = substitution_expr.apply(state) + state, indices = get_variable_indices_from_variable(state, variable) - assert substitution.shape == (1, 1), f'{substitution=}' + if indices is None: + return acc - polynomial = substitution.get_poly(0, 0) + state, substitution = expr.apply(state) - elif isinstance(substitution_expr, int) or isinstance(substitution_expr, float): - polynomial = {tuple(): substitution_expr} + assert substitution.shape[1] == 1, f'{substitution=}' - else: - raise Exception(f'{substitution_expr=} not recognized') + def gen_polynomials(): + for row in range(substitution.shape[0]): + yield substitution.get_poly(row, 0) - return state, result + (polynomial,) + polynomials = tuple(gen_polynomials()) - *_, (state, substitutions) = tuple(itertools.accumulate( + return state, acc_variable + indices, acc_substitution + polynomials + + *_, (state, variable_indices, substitutions) = tuple(itertools.accumulate( self.substitutions, acc_substitutions, - initial=(state, tuple()), + initial=(state, tuple(), tuple()), )) + if len(substitutions) == 1: + substitutions = tuple(substitutions[0] for _ in variable_indices) + + else: + assert len(variable_indices) == len(substitutions), f'{substitutions=}' + terms = {} for row in range(underlying.shape[0]): @@ -81,7 +84,6 @@ class SubstituteExprMixin(ExpressionBaseMixin): index = variable_indices.index(variable) substitution = substitutions[index] - for _ in range(count): next = {} diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py index e144ae5..cadc856 100644 --- a/polymatrix/expression/mixins/truncateexprmixin.py +++ b/polymatrix/expression/mixins/truncateexprmixin.py @@ -5,7 +5,7 @@ from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.polymatrix import PolyMatrix from polymatrix.expressionstate.expressionstate import ExpressionState -from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable class TruncateExprMixin(ExpressionBaseMixin): @@ -35,7 +35,7 @@ class TruncateExprMixin(ExpressionBaseMixin): state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - state, variable_indices = get_variable_indices(state, self.variables) + state, variable_indices = get_variable_indices_from_variable(state, self.variables) terms = {} diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index aaf5b0b..371894e 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -1,40 +1,104 @@ +import itertools +import typing from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -def get_variable_indices(state, variables): +def get_variable_indices_from_variable(state, variable) -> typing.Optional[tuple[int]]: + + if isinstance(variable, ExpressionBaseMixin): + state, variable_polynomial = variable.apply(state) + + assert variable_polynomial.shape[1] == 1 + + def gen_variables_indices(): + + for row in range(variable_polynomial.shape[0]): + row_terms = variable_polynomial.get_poly(row, 0) + + assert len(row_terms) == 1, f'{row_terms} contains more than one term' + + for monomial in row_terms.keys(): + assert len(monomial) <= 1, f'{monomial=} contains more than one variable' + + if len(monomial) == 0: + continue + + assert monomial[0][1] == 1, f'{monomial[0]=}' + yield monomial[0][0] + + variable_indices = tuple(gen_variables_indices()) + + elif isinstance(variable, int): + variable_indices = (variable,) + + elif variable in state.offset_dict: + variable_indices = (state.offset_dict[variable][0],) - global_state = [state] + else: + variable_indices = None + # raise Exception(f'variable index not found for {variable=}, {state.offset_dict=}') + + return state, variable_indices + + +def get_variable_indices(state, variables): if not isinstance(variables, tuple): variables = (variables,) - def gen_indices(): - for variable in variables: - if isinstance(variable, ExpressionBaseMixin): - global_state[0], variable_polynomial = variable.apply(global_state[0]) + # assert isinstance(variables, tuple), f'{variables=}' + + def acc_variable_indices(acc, variable): + state, indices = acc + + state, new_indices = get_variable_indices_from_variable(state, variable) + + return state, indices + new_indices + + *_, (state, indices) = itertools.accumulate( + variables, + acc_variable_indices, + initial=(state, tuple()), + ) + + return state, indices + + + # global_state = [state] + + # assert isinstance(variables, tuple), f'{variables=}' + + # # if not isinstance(variables, tuple): + # # variables = (variables,) + + # def gen_indices(): + # for variable in variables: + # if isinstance(variable, ExpressionBaseMixin): + # global_state[0], variable_polynomial = variable.apply(global_state[0]) - assert variable_polynomial.shape[1] == 1 + # assert variable_polynomial.shape[1] == 1 - for row in range(variable_polynomial.shape[0]): - row_terms = variable_polynomial.get_poly(row, 0) + # for row in range(variable_polynomial.shape[0]): + # row_terms = variable_polynomial.get_poly(row, 0) - assert len(row_terms) == 1, f'{row_terms} contains more than one term' + # assert len(row_terms) == 1, f'{row_terms} contains more than one term' - for monomial in row_terms.keys(): - assert len(monomial) <= 1, f'{monomial=} contains more than one variable' + # for monomial in row_terms.keys(): + # assert len(monomial) <= 1, f'{monomial=} contains more than one variable' - if len(monomial) == 0: - continue + # if len(monomial) == 0: + # continue - assert monomial[0][1] == 1, f'{monomial[0]=}' - yield monomial[0][0] + # assert monomial[0][1] == 1, f'{monomial[0]=}' + # yield monomial[0][0] - elif isinstance(variable, int): - yield variable + # elif isinstance(variable, int): + # yield variable - else: - yield global_state[0].offset_dict[variable][0] + # # else: + # elif variable in global_state[0].offset_dict: + # yield global_state[0].offset_dict[variable][0] - indices = tuple(gen_indices()) + # indices = tuple(gen_indices()) - return global_state[0], indices + # return global_state[0], indices diff --git a/polymatrix/expression/utils/mergemonomialindices.py b/polymatrix/expression/utils/mergemonomialindices.py index 7f8ba15..c572852 100644 --- a/polymatrix/expression/utils/mergemonomialindices.py +++ b/polymatrix/expression/utils/mergemonomialindices.py @@ -19,9 +19,4 @@ def merge_monomial_indices(monomials): else: m1_dict[index] = count - # return tuple(sorted( - # m1_dict.items(), - # key=lambda m: m[0], - # )) - return sort_monomial_indices(m1_dict.items()) diff --git a/polymatrix/expressionstate/mixins/expressionstatemixin.py b/polymatrix/expressionstate/mixins/expressionstatemixin.py index e08e1eb..18dba05 100644 --- a/polymatrix/expressionstate/mixins/expressionstatemixin.py +++ b/polymatrix/expressionstate/mixins/expressionstatemixin.py @@ -16,7 +16,7 @@ class ExpressionStateMixin( @abc.abstractmethod def n_param(self) -> int: """ - current number of parameters used in polynomial matrix expressions + number of parameters used in polynomial matrix expressions """ ... @@ -24,11 +24,16 @@ class ExpressionStateMixin( @property @abc.abstractmethod def offset_dict(self) -> dict[tuple[typing.Any], tuple[int, int]]: + """ + a variable consists of one or more parameters indexed by a start + and an end index + """ + ... @property @abc.abstractmethod - def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]: + def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]: ... def get_key_from_offset(self, offset: int): |