diff options
Diffstat (limited to '')
43 files changed, 711 insertions, 326 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 78b4afa..139e391 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -1,11 +1,22 @@ +import itertools +import numpy as np +import scipy.sparse +# import polymatrix.statemonad + from polymatrix.expression.expression import Expression +from polymatrix.expression.expressionstate import ExpressionState from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr from polymatrix.expression.init.initexpression import init_expression from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr from polymatrix.expression.init.initkktexpr import init_kkt_expr +from polymatrix.expression.init.initlinearmatrixinexpr import init_linear_matrix_in_expr from polymatrix.expression.init.initvstackexpr import init_v_stack_expr from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.utils.getvariableindices import get_variable_indices +from polymatrix.statemonad.init.initstatemonad import init_state_monad +from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin +from polymatrix.utils import monomial_to_index def from_( @@ -66,4 +77,249 @@ def kkt( ) ) - + # self_cost = cost + # self_variables = variables + # self_equality = equality + + # def func(state: ExpressionState): + # state, cost = self_cost.apply(state=state) + + # assert cost.shape[1] == 1 + + # if self_equality is not None: + + # state, equality = self_equality.apply(state=state) + + # state, equality_der = self_equality.diff( + # self_variables, + # introduce_derivatives=True, + # ).apply(state) + + # assert cost.shape[0] == equality_der.shape[1] + + # def acc_nu_variables(acc, v): + # state, nu_variables = acc + + # nu_variable = state.n_param + # state = state.register(n_param=1) + + # return state, nu_variables + [nu_variable] + + # *_, (state, nu_variables) = tuple(itertools.accumulate( + # range(equality.shape[0]), + # acc_nu_variables, + # initial=(state, []), + # )) + + # else: + # nu_variables = tuple() + + # idx_start = 0 + + # terms = {} + + # for row in range(cost.shape[0]): + # try: + # monomial_terms = cost.get_poly(row, 0) + # except KeyError: + # monomial_terms = {} + + # for eq_idx, nu_variable in enumerate(nu_variables): + + # try: + # underlying_terms = equality_der.get_poly(eq_idx, row) + # except KeyError: + # continue + + # for monomial, value in underlying_terms.items(): + # new_monomial = monomial + (nu_variable,) + + # if new_monomial not in monomial_terms: + # monomial_terms[new_monomial] = 0 + + # monomial_terms[new_monomial] += value + + # terms[idx_start, 0] = monomial_terms + # idx_start += 1 + + # cost_expr = init_expression(init_from_terms_expr( + # terms=terms, + # shape=(idx_start, 1), + # )) + + # terms = {} + # for eq_idx, nu_variable in enumerate(nu_variables): + # terms[eq_idx, 0] = {(nu_variable,): 1} + + # nu_expr = init_expression(init_from_terms_expr( + # terms=terms, + # shape=(len(nu_variables), 1), + # )) + + # return state, (cost_expr, nu_expr) + + # return StateMonadMixin.init(func) + +# def to_linear_matrix( +# expr: Expression, +# variables: Expression, +# ) -> StateMonad[ExpressionState, tuple[Expression, ...]]: +# def func(state: ExpressionState): +# state, variable_indices = get_variable_indices(state, variables) + +# def gen_matrices(): +# for variable_index in variable_indices: +# yield init_linear_matrix_in_expr( +# underlying=expr, +# variable=variable_index, +# ) + +# matrices = tuple(gen_matrices()) + +# return state, matrices + +# return StateMonad.init(func) + + +def to_matrix_equations( + expr: Expression, +) -> StateMonadMixin[ExpressionState, tuple[tuple[np.ndarray, ...], tuple[int, ...]]]: + def func(state: ExpressionState): + state, underlying = expr.apply(state) + + assert underlying.shape[1] == 1 + + def gen_used_variables(): + def gen_used_auxillary_variables(considered): + monomial_terms = state.auxillary_equations[considered[-1]] + for monomial in monomial_terms.keys(): + for variable in monomial: + yield variable + + if variable not in considered and variable in state.auxillary_equations: + yield from gen_used_auxillary_variables(considered + (variable,)) + + for row in range(underlying.shape[0]): + for col in range(underlying.shape[1]): + + try: + underlying_terms = underlying.get_poly(row, col) + except KeyError: + continue + + for monomial in underlying_terms.keys(): + for variable in monomial: + yield variable + + if variable in state.auxillary_equations: + yield from gen_used_auxillary_variables((variable,)) + + used_variables = set(gen_used_variables()) + + ordered_variable_index = tuple(sorted(used_variables)) + variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)} + + n_param = len(ordered_variable_index) + + A = np.zeros((n_param, 1), dtype=np.float32) + B = np.zeros((n_param, n_param), dtype=np.float32) + C = scipy.sparse.dok_array((n_param, n_param**2), dtype=np.float32) + + def populate_matrices(monomial_terms, row): + for monomial, value in monomial_terms.items(): + new_monomial = tuple(variable_index_map[var] for var in monomial) + col = monomial_to_index(n_param, new_monomial) + + match len(new_monomial): + case 0: + A[row, col] = value + case 1: + B[row, col] = value + case 2: + C[row, col] = value + case _: + raise Exception(f'illegal case {new_monomial=}') + + for row in range(underlying.shape[0]): + try: + underlying_terms = underlying.get_poly(row, 0) + except KeyError: + continue + + populate_matrices( + monomial_terms=underlying_terms, + row=row, + ) + + current_row = underlying.shape[0] + + for key, monomial_terms in state.auxillary_equations.items(): + if key in ordered_variable_index: + populate_matrices( + monomial_terms=monomial_terms, + row=current_row, + ) + current_row += 1 + + # assert current_row == n_param, f'{current_row} is not {n_param}' + + return state, ((A, B, C), ordered_variable_index) + + return StateMonadMixin.init(func) + +def to_constant_matrix( + expr: Expression, +) -> StateMonadMixin[ExpressionState, np.ndarray]: + + def func(state: ExpressionState): + state, underlying = expr.apply(state) + + A = np.zeros(underlying.shape, dtype=np.float32) + + for row in range(underlying.shape[0]): + for col in range(underlying.shape[1]): + try: + underlying_terms = underlying.get_poly(row, col) + except KeyError: + continue + + for monomial, value in underlying_terms.items(): + + if len(monomial) == 0: + A[row, col] = value + + return state, A + + return init_state_monad(func) + + +def rows( + expr: Expression, +) -> StateMonadMixin[ExpressionState, np.ndarray]: + + def func(state: ExpressionState): + state, underlying = expr.apply(state) + + def gen_row_terms(): + for row in range(underlying.shape[0]): + + terms = {} + + for col in range(underlying.shape[1]): + try: + underlying_terms = underlying.get_poly(row, col) + except KeyError: + continue + + terms[0, col] = underlying_terms + + yield init_expression(underlying=init_from_terms_expr( + terms=terms, + shape=(1, underlying.shape[1]) + )) + + row_terms = tuple(gen_row_terms()) + + return state, row_terms + + return init_state_monad(func) diff --git a/polymatrix/expression/forallexpr.py b/polymatrix/expression/forallexpr.py deleted file mode 100644 index 972d553..0000000 --- a/polymatrix/expression/forallexpr.py +++ /dev/null @@ -1,4 +0,0 @@ -from polymatrix.expression.mixins.forallexprmixin import ForAllExprMixin - -class ForAllExpr(ForAllExprMixin): - pass diff --git a/polymatrix/expression/impl/forallexprimpl.py b/polymatrix/expression/impl/linearinexprimpl.py index 0960479..d97051f 100644 --- a/polymatrix/expression/impl/forallexprimpl.py +++ b/polymatrix/expression/impl/linearinexprimpl.py @@ -1,9 +1,9 @@ import dataclass_abc -from polymatrix.expression.forallexpr import ForAllExpr +from polymatrix.expression.linearinexpr import LinearInExpr from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @dataclass_abc.dataclass_abc(frozen=True) -class ForAllExprImpl(ForAllExpr): +class LinearInExprImpl(LinearInExpr): underlying: ExpressionBaseMixin variables: tuple diff --git a/polymatrix/expression/impl/linearmatrixinexprimpl.py b/polymatrix/expression/impl/linearmatrixinexprimpl.py new file mode 100644 index 0000000..b40e6a6 --- /dev/null +++ b/polymatrix/expression/impl/linearmatrixinexprimpl.py @@ -0,0 +1,9 @@ +import dataclass_abc +from polymatrix.expression.linearmatrixinexpr import LinearMatrixInExpr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class LinearMatrixInExprImpl(LinearMatrixInExpr): + underlying: ExpressionBaseMixin + variable: int diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py index 9d0fc0d..525a697 100644 --- a/polymatrix/expression/init/initevalexpr.py +++ b/polymatrix/expression/init/initevalexpr.py @@ -5,14 +5,14 @@ from polymatrix.expression.impl.evalexprimpl import EvalExprImpl def init_eval_expr( underlying: ExpressionBaseMixin, - values: tuple, - variables: tuple = None, + variables: tuple, + values: tuple = None, ): - if variables is None: - assert isinstance(values, tuple) + if values is None: + assert isinstance(variables, tuple) - variables, values = tuple(zip(*values)) + variables, values = tuple(zip(*variables)) elif isinstance(values, np.ndarray): values = tuple(values) @@ -20,8 +20,8 @@ def init_eval_expr( elif not isinstance(values, tuple): values = (values,) - if not isinstance(variables, tuple): - variables = (variables,) + # if not isinstance(variables, tuple): + # variables = (variables,) return EvalExprImpl( underlying=underlying, diff --git a/polymatrix/expression/init/initfromarrayexpr.py b/polymatrix/expression/init/initfromarrayexpr.py index 6aab26c..e6f57c8 100644 --- a/polymatrix/expression/init/initfromarrayexpr.py +++ b/polymatrix/expression/init/initfromarrayexpr.py @@ -21,7 +21,7 @@ def init_from_array_expr( assert all(len(col) == n_col for col in data) case _: - data = (data,) + data = tuple((e,) for e in data) case _: data = ((data,),) diff --git a/polymatrix/expression/init/initforallexpr.py b/polymatrix/expression/init/initlinearinexpr.py index 84388d2..f7f76e4 100644 --- a/polymatrix/expression/init/initforallexpr.py +++ b/polymatrix/expression/init/initlinearinexpr.py @@ -1,12 +1,12 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expression.impl.forallexprimpl import ForAllExprImpl +from polymatrix.expression.impl.linearinexprimpl import LinearInExprImpl -def init_for_all_expr( +def init_linear_in_expr( underlying: ExpressionBaseMixin, variables: tuple, ): - return ForAllExprImpl( + return LinearInExprImpl( underlying=underlying, variables=variables, ) diff --git a/polymatrix/expression/init/initlinearmatrixinexpr.py b/polymatrix/expression/init/initlinearmatrixinexpr.py new file mode 100644 index 0000000..cd4ce97 --- /dev/null +++ b/polymatrix/expression/init/initlinearmatrixinexpr.py @@ -0,0 +1,12 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.linearmatrixinexprimpl import LinearMatrixInExprImpl + + +def init_linear_matrix_in_expr( + underlying: ExpressionBaseMixin, + variable: int, +): + return LinearMatrixInExprImpl( + underlying=underlying, + variable=variable, +) diff --git a/polymatrix/expression/linearinexpr.py b/polymatrix/expression/linearinexpr.py new file mode 100644 index 0000000..4edf8b3 --- /dev/null +++ b/polymatrix/expression/linearinexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin + +class LinearInExpr(LinearInExprMixin): + pass diff --git a/polymatrix/expression/linearmatrixinexpr.py b/polymatrix/expression/linearmatrixinexpr.py new file mode 100644 index 0000000..2bce2e7 --- /dev/null +++ b/polymatrix/expression/linearmatrixinexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.linearmatrixinexprmixin import LinearMatrixInExprMixin + +class LinearMatrixInExpr(LinearMatrixInExprMixin): + pass diff --git a/polymatrix/expression/mixins/accumulateexprmixin.py b/polymatrix/expression/mixins/accumulateexprmixin.py index 1f834bf..76e1717 100644 --- a/polymatrix/expression/mixins/accumulateexprmixin.py +++ b/polymatrix/expression/mixins/accumulateexprmixin.py @@ -24,7 +24,7 @@ class AccumulateExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index e068ba1..d274f2a 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -26,14 +26,14 @@ class AdditionExprMixin(ExpressionBaseMixin): # return self.left.shape # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - assert left.shape == right.shape + assert left.shape == right.shape, f'{left.shape} != {right.shape}' terms = {} diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index 43c0efe..a551c1f 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -4,8 +4,6 @@ import collections import dataclasses import itertools import typing -import dataclass_abc -from numpy import nonzero, var from polymatrix.expression.init.initderivativekey import init_derivative_key from polymatrix.expression.init.initpolymatrix import init_poly_matrix @@ -31,51 +29,15 @@ class DerivativeExprMixin(ExpressionBaseMixin): def introduce_derivatives(self) -> bool: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # match self.variables: - # case ExpressionBaseMixin(): - # n_cols = self.variables.shape[0] - - # case _: - # n_cols = len(self.variables) - - # return self.underlying.shape[0], n_cols - # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: - # match self.variables: - # case ExpressionBaseMixin(): - # state, variables = self.variables.apply(state) - - # assert variables.shape[1] == 1 - - # def gen_indices(): - # for row in range(variables.shape[0]): - # row_terms = variables.get_poly(row, 0) - - # assert len(row_terms) == 1 - - # for monomial in row_terms.keys(): - # assert len(monomial) == 1 - # yield monomial[0] - - # diff_wrt_variables = tuple(gen_indices()) - - # case _: - # def gen_indices(): - # for variable in self.variables: - # if variable in state.offset_dict: - # yield state.offset_dict[variable][0] - - # diff_wrt_variables = tuple(gen_indices()) + state, underlying = self.underlying.apply(state=state) - state, diff_wrt_variables = get_variable_indices(self.variables, state) + state, diff_wrt_variables = get_variable_indices(state, self.variables) def get_derivative_terms( monomial_terms, @@ -181,8 +143,6 @@ class DerivativeExprMixin(ExpressionBaseMixin): return state, dict(derivation_terms) - state, underlying = self.underlying.apply(state=state) - terms = {} for row in range(underlying.shape[0]): diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py index f7f76eb..5b69d3f 100644 --- a/polymatrix/expression/mixins/determinantexprmixin.py +++ b/polymatrix/expression/mixins/determinantexprmixin.py @@ -22,7 +22,7 @@ class DeterminantExprMixin(ExpressionBaseMixin): # return self.underlying.shape[0], 1 # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py index 7b5e0c5..4b25ec6 100644 --- a/polymatrix/expression/mixins/divisionexprmixin.py +++ b/polymatrix/expression/mixins/divisionexprmixin.py @@ -25,7 +25,7 @@ class DivisionExprMixin(ExpressionBaseMixin): # return self.left.shape # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index 931e022..ca5d921 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -25,7 +25,7 @@ class ElemMultExprMixin(ExpressionBaseMixin): # return self.left.shape # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py index e61cebd..4d78b26 100644 --- a/polymatrix/expression/mixins/evalexprmixin.py +++ b/polymatrix/expression/mixins/evalexprmixin.py @@ -26,15 +26,21 @@ class EvalExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - state, variable_indices = get_variable_indices(self.variables, state) + state, variable_indices = get_variable_indices(state, self.variables) - assert len(variable_indices) == len(self.values) + if len(self.values) == 1: + values = tuple(self.values[0] for _ in variable_indices) + + else: + assert len(variable_indices) == len(self.values) + + values = self.values terms = {} @@ -55,7 +61,7 @@ class EvalExprMixin(ExpressionBaseMixin): if variable in variable_indices: index = variable_indices.index(variable) - new_value = value * self.values[index] + new_value = value * values[index] return monomial, new_value else: diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py index 43b3fdb..baf605d 100644 --- a/polymatrix/expression/mixins/expressionbasemixin.py +++ b/polymatrix/expression/mixins/expressionbasemixin.py @@ -1,17 +1,20 @@ import abc from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin +from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin class ExpressionBaseMixin( - abc.ABC, + # StateMonad[ExpressionStateMixin, PolyMatrixMixin], + abc.ABC ): - # @property - # @abc.abstractclassmethod - # def shape(self) -> tuple[int, int]: - # ... + @abc.abstractmethod + def _apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + ... @abc.abstractmethod def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: - ... + assert isinstance(state, ExpressionStateMixin), f'{state} is not of type {ExpressionStateMixin.__name__}' + + return self._apply(state) diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py index b43b3be..aa9ca17 100644 --- a/polymatrix/expression/mixins/expressionmixin.py +++ b/polymatrix/expression/mixins/expressionmixin.py @@ -10,9 +10,10 @@ from polymatrix.expression.init.initdeterminantexpr import init_determinant_expr from polymatrix.expression.init.initdivisionexpr import init_division_expr from polymatrix.expression.init.initelemmultexpr import init_elem_mult_expr from polymatrix.expression.init.initevalexpr import init_eval_expr -from polymatrix.expression.init.initforallexpr import init_for_all_expr +from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr from polymatrix.expression.init.initgetitemexpr import init_get_item_expr +from polymatrix.expression.init.initlinearmatrixinexpr import init_linear_matrix_in_expr from polymatrix.expression.init.initmatrixmultexpr import init_matrix_mult_expr from polymatrix.expression.init.initparametrizetermsexpr import init_parametrize_terms_expr from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr @@ -188,12 +189,21 @@ class ExpressionMixin( def linear_in(self, variables: tuple) -> 'ExpressionMixin': return dataclasses.replace( self, - underlying=init_for_all_expr( + underlying=init_linear_in_expr( underlying=self.underlying, variables=variables, ), ) + def linear_matrix_in(self, variable) -> 'ExpressionMixin': + return dataclasses.replace( + self, + underlying=init_linear_matrix_in_expr( + underlying=self.underlying, + variable=variable, + ), + ) + def quadratic_in(self, variables: tuple) -> 'ExpressionMixin': return dataclasses.replace( self, @@ -213,15 +223,15 @@ class ExpressionMixin( def eval( self, - values: tuple[float, ...], - variables: tuple = None, + variable: tuple, + value: tuple[float, ...] = None, ) -> 'ExpressionMixin': return dataclasses.replace( self, underlying=init_eval_expr( underlying=self.underlying, - variables=variables, - values=values, + variables=variable, + values=value, ), ) diff --git a/polymatrix/expression/mixins/fromarrayexprmixin.py b/polymatrix/expression/mixins/fromarrayexprmixin.py index 86f8b72..e79810c 100644 --- a/polymatrix/expression/mixins/fromarrayexprmixin.py +++ b/polymatrix/expression/mixins/fromarrayexprmixin.py @@ -15,13 +15,8 @@ class FromArrayExprMixin(ExpressionBaseMixin): def data(self) -> tuple[tuple[float]]: pass - # # overwrites abstract method of `PolyMatrixExprBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return len(self.data), len(self.data[0]) - # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py index 4ada8ec..1af3f11 100644 --- a/polymatrix/expression/mixins/fromtermsexprmixin.py +++ b/polymatrix/expression/mixins/fromtermsexprmixin.py @@ -21,7 +21,7 @@ class FromTermsExprMixin(ExpressionBaseMixin): pass # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py index 26ed728..c8ac02e 100644 --- a/polymatrix/expression/mixins/getitemexprmixin.py +++ b/polymatrix/expression/mixins/getitemexprmixin.py @@ -22,13 +22,8 @@ class GetItemExprMixin(ExpressionBaseMixin): def index(self) -> tuple[int, int]: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return 1, 1 - # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/kktexprmixin.py b/polymatrix/expression/mixins/kktexprmixin.py index b4c9e72..20949a1 100644 --- a/polymatrix/expression/mixins/kktexprmixin.py +++ b/polymatrix/expression/mixins/kktexprmixin.py @@ -39,7 +39,7 @@ class KKTExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: @@ -170,7 +170,7 @@ class KKTExprMixin(ExpressionBaseMixin): except KeyError: continue - # f(x) <= 0 + # f(x) <= -0.01 terms[idx_start, 0] = underlying_terms | {(r_inequality, r_inequality): 1} idx_start += 1 diff --git a/polymatrix/expression/mixins/forallexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index a988b24..557bed0 100644 --- a/polymatrix/expression/mixins/forallexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -9,7 +9,7 @@ from polymatrix.expression.polymatrix import PolyMatrix from polymatrix.expression.expressionstate import ExpressionState -class ForAllExprMixin(ExpressionBaseMixin): +class LinearInExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod def underlying(self) -> ExpressionBaseMixin: @@ -20,18 +20,15 @@ class ForAllExprMixin(ExpressionBaseMixin): def variables(self) -> tuple: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return self.underlying.shape - # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) + # todo: uncomment this + # state, variable_indices = get_variable_indices(state, self.variables) variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict) terms = {} diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py new file mode 100644 index 0000000..321d955 --- /dev/null +++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py @@ -0,0 +1,64 @@ + +import abc +import collections +from numpy import var + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.expressionstate import ExpressionState +from polymatrix.expression.utils.getvariableindices import get_variable_indices + + +class LinearMatrixInExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + @property + @abc.abstractmethod + def variable(self) -> int: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def _apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + state, underlying = self.underlying.apply(state=state) + + state, variable_indices = get_variable_indices(state, self.variable) + + assert len(variable_indices) == 1 + + terms = {} + + for row in range(underlying.shape[0]): + for col in range(underlying.shape[1]): + + try: + underlying_terms = underlying.get_poly(row, col) + except KeyError: + continue + + monomial_terms = collections.defaultdict(float) + + for monomial, value in underlying_terms.items(): + + x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) + + # only take linear terms + if len(x_monomial) == 1: + p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices) + + monomial_terms[p_monomial] += value + + terms[row, col] = dict(monomial_terms) + + poly_matrix = init_poly_matrix( + terms=terms, + shape=underlying.shape, + ) + + return state, poly_matrix diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py index d1c96d4..f1e1720 100644 --- a/polymatrix/expression/mixins/matrixmultexprmixin.py +++ b/polymatrix/expression/mixins/matrixmultexprmixin.py @@ -19,13 +19,8 @@ class MatrixMultExprMixin(ExpressionBaseMixin): def right(self) -> ExpressionBaseMixin: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return (self.left.shape[0], self.right.shape[1]) - # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/parametrizetermsexprmixin.py b/polymatrix/expression/mixins/parametrizetermsexprmixin.py index 4d46ece..f20306c 100644 --- a/polymatrix/expression/mixins/parametrizetermsexprmixin.py +++ b/polymatrix/expression/mixins/parametrizetermsexprmixin.py @@ -32,46 +32,46 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin): # def shape(self) -> tuple[int, int]: # return self.underlying.shape - @dataclass_abc.dataclass_abc - class ParametrizeTermsPolyMatrix(PolyMatrixAsDictMixin): - shape: tuple[int, int] - terms: dict - start_index: int - n_param: int + # @dataclass_abc.dataclass_abc + # class ParametrizeTermsPolyMatrix(PolyMatrixAsDictMixin): + # shape: tuple[int, int] + # terms: dict + # start_index: int + # n_param: int - @property - def param(self) -> tuple[int, int]: - outer_self = self + # @property + # def param(self) -> tuple[int, int]: + # outer_self = self - @dataclass_abc.dataclass_abc(frozen=True) - class ParameterExpr(ExpressionBaseMixin): - def apply( - self, - state: ExpressionStateMixin, - ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + # @dataclass_abc.dataclass_abc(frozen=True) + # class ParameterExpr(ExpressionBaseMixin): + # def apply( + # self, + # state: ExpressionStateMixin, + # ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: - state, poly_matrix = outer_self.apply(state) + # state, poly_matrix = outer_self.apply(state) - n_param = poly_matrix.n_param - start_index = poly_matrix.start_index + # n_param = poly_matrix.n_param + # start_index = poly_matrix.start_index - def gen_monomials(): - for rel_index in range(n_param): - yield {(start_index + rel_index,): 1} + # def gen_monomials(): + # for rel_index in range(n_param): + # yield {(start_index + rel_index,): 1} - terms = {(row, 0): monomial_terms for row, monomial_terms in enumerate(gen_monomials())} + # terms = {(row, 0): monomial_terms for row, monomial_terms in enumerate(gen_monomials())} - poly_matrix = init_poly_matrix( - terms=terms, - shape=(n_param, 1), - ) + # poly_matrix = init_poly_matrix( + # terms=terms, + # shape=(n_param, 1), + # ) - return state, poly_matrix + # return state, poly_matrix - return ParameterExpr() + # return ParameterExpr() # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: @@ -87,8 +87,7 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin): variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict) - idx_start = state.n_param - n_param = 0 + start_index = state.n_param terms = {} for row in range(underlying.shape[0]): @@ -99,32 +98,42 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin): except KeyError: continue - terms_row_col = {} - collected_terms = [] + def gen_x_monomial_terms(): + for monomial, value in underlying_terms.items(): + x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) + yield monomial, x_monomial, value - for monomial, value in underlying_terms.items(): + x_monomial_terms = tuple(gen_x_monomial_terms()) - x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) + collected_terms = tuple(sorted(set((x_monomial for _, x_monomial, _ in x_monomial_terms)))) - if x_monomial not in collected_terms: - collected_terms.append(x_monomial) - - idx = idx_start + n_param + collected_terms.index(x_monomial) + terms_row_col = {} - new_monomial = monomial + (idx,) + for monomial, x_monomial, value in x_monomial_terms: + + new_monomial = monomial + (start_index + collected_terms.index(x_monomial),) terms_row_col[new_monomial] = value - n_param += len(collected_terms) terms[row, col] = terms_row_col - state = state.register(key=self, n_param=n_param) + start_index += len(collected_terms) + + state = state.register( + key=self, + n_param=start_index - state.n_param, + ) + + # poly_matrix = ParametrizeTermsExprMixin.ParametrizeTermsPolyMatrix( + # terms=terms, + # shape=underlying.shape, + # start_index=idx_start, + # n_param=n_param, + # ) - poly_matrix = ParametrizeTermsExprMixin.ParametrizeTermsPolyMatrix( + poly_matrix = init_poly_matrix( terms=terms, shape=underlying.shape, - start_index=idx_start, - n_param=n_param, ) state = dataclasses.replace( diff --git a/polymatrix/expression/mixins/polymatrixmixin.py b/polymatrix/expression/mixins/polymatrixmixin.py index c0dcac2..f83a62c 100644 --- a/polymatrix/expression/mixins/polymatrixmixin.py +++ b/polymatrix/expression/mixins/polymatrixmixin.py @@ -18,84 +18,84 @@ class PolyMatrixMixin(abc.ABC): def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: ... - def get_equations( - self, - state: ExpressionStateMixin, - ): - assert self.shape[1] == 1 - - def gen_used_variables(): - def gen_used_auxillary_variables(considered): - monomial_terms = state.auxillary_equations[considered[-1]] - for monomial in monomial_terms.keys(): - for variable in monomial: - yield variable - - if variable not in considered and variable in state.auxillary_equations: - yield from gen_used_auxillary_variables(considered + (variable,)) - - for row in range(self.shape[0]): - for col in range(self.shape[1]): - - try: - underlying_terms = self.get_poly(row, col) - except KeyError: - continue - - for monomial in underlying_terms.keys(): - for variable in monomial: - yield variable - - if variable in state.auxillary_equations: - yield from gen_used_auxillary_variables((variable,)) - - used_variables = set(gen_used_variables()) - - ordered_variable_index = tuple(sorted(used_variables)) - variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)} - - n_param = len(ordered_variable_index) - - A = np.zeros((n_param, 1), dtype=np.float32) - B = np.zeros((n_param, n_param), dtype=np.float32) - C = scipy.sparse.dok_array((n_param, n_param**2), dtype=np.float32) - - def populate_matrices(monomial_terms, row): - for monomial, value in monomial_terms.items(): - new_monomial = tuple(variable_index_map[var] for var in monomial) - col = monomial_to_index(n_param, new_monomial) - - match len(new_monomial): - case 0: - A[row, col] = value - case 1: - B[row, col] = value - case 2: - C[row, col] = value - case _: - raise Exception(f'illegal case {new_monomial=}') - - for row in range(self.shape[0]): - try: - underlying_terms = self.get_poly(row, 0) - except KeyError: - continue - - populate_matrices( - monomial_terms=underlying_terms, - row=row, - ) - - current_row = self.shape[0] - - for key, monomial_terms in state.auxillary_equations.items(): - if key in ordered_variable_index: - populate_matrices( - monomial_terms=monomial_terms, - row=current_row, - ) - current_row += 1 - - # assert current_row == n_param, f'{current_row} is not {n_param}' - - return (A, B, C), ordered_variable_index + # def get_equations( + # self, + # state: ExpressionStateMixin, + # ): + # assert self.shape[1] == 1 + + # def gen_used_variables(): + # def gen_used_auxillary_variables(considered): + # monomial_terms = state.auxillary_equations[considered[-1]] + # for monomial in monomial_terms.keys(): + # for variable in monomial: + # yield variable + + # if variable not in considered and variable in state.auxillary_equations: + # yield from gen_used_auxillary_variables(considered + (variable,)) + + # for row in range(self.shape[0]): + # for col in range(self.shape[1]): + + # try: + # underlying_terms = self.get_poly(row, col) + # except KeyError: + # continue + + # for monomial in underlying_terms.keys(): + # for variable in monomial: + # yield variable + + # if variable in state.auxillary_equations: + # yield from gen_used_auxillary_variables((variable,)) + + # used_variables = set(gen_used_variables()) + + # ordered_variable_index = tuple(sorted(used_variables)) + # variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)} + + # n_param = len(ordered_variable_index) + + # A = np.zeros((n_param, 1), dtype=np.float32) + # B = np.zeros((n_param, n_param), dtype=np.float32) + # C = scipy.sparse.dok_array((n_param, n_param**2), dtype=np.float32) + + # def populate_matrices(monomial_terms, row): + # for monomial, value in monomial_terms.items(): + # new_monomial = tuple(variable_index_map[var] for var in monomial) + # col = monomial_to_index(n_param, new_monomial) + + # match len(new_monomial): + # case 0: + # A[row, col] = value + # case 1: + # B[row, col] = value + # case 2: + # C[row, col] = value + # case _: + # raise Exception(f'illegal case {new_monomial=}') + + # for row in range(self.shape[0]): + # try: + # underlying_terms = self.get_poly(row, 0) + # except KeyError: + # continue + + # populate_matrices( + # monomial_terms=underlying_terms, + # row=row, + # ) + + # current_row = self.shape[0] + + # for key, monomial_terms in state.auxillary_equations.items(): + # if key in ordered_variable_index: + # populate_matrices( + # monomial_terms=monomial_terms, + # row=current_row, + # ) + # current_row += 1 + + # # assert current_row == n_param, f'{current_row} is not {n_param}' + + # return (A, B, C), ordered_variable_index diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py index 5aeef1b..ce6b5c2 100644 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ b/polymatrix/expression/mixins/quadraticinexprmixin.py @@ -24,7 +24,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin): # return 2*(len(self.variables),) # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: @@ -49,7 +49,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin): x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices) - assert len(x_monomial) == 2 + assert len(x_monomial) == 2, f'{x_monomial} should be of length 2' assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted' key = tuple(reversed(x_monomial)) diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py index 26777e9..5d5c00f 100644 --- a/polymatrix/expression/mixins/repmatexprmixin.py +++ b/polymatrix/expression/mixins/repmatexprmixin.py @@ -16,12 +16,8 @@ class RepMatExprMixin(ExpressionBaseMixin): def repetition(self) -> tuple[int, int]: ... - # @property - # def shape(self) -> tuple[int, int]: - # return tuple(s*r for s, r in zip(self.underlying.shape, self.repetition)) - # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py index 3079e0f..0aac30a 100644 --- a/polymatrix/expression/mixins/toquadraticexprmixin.py +++ b/polymatrix/expression/mixins/toquadraticexprmixin.py @@ -15,13 +15,8 @@ class ToQuadraticExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return self.underlying.shape - # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py index d24cf53..fa074d7 100644 --- a/polymatrix/expression/mixins/transposeexprmixin.py +++ b/polymatrix/expression/mixins/transposeexprmixin.py @@ -17,13 +17,8 @@ class TransposeExprMixin(ExpressionBaseMixin): def underlying(self) -> ExpressionBaseMixin: ... - # # overwrites abstract method of `PolyMatrixExprBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return self.underlying.shape[1], self.underlying.shape[0] - # overwrites abstract method of `PolyMatrixExprBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py index c192ae3..90cfb20 100644 --- a/polymatrix/expression/mixins/vstackexprmixin.py +++ b/polymatrix/expression/mixins/vstackexprmixin.py @@ -15,14 +15,8 @@ class VStackExprMixin(ExpressionBaseMixin): def underlying(self) -> tuple[ExpressionBaseMixin, ...]: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # n_row = sum(expr.shape[0] for expr in self.underlying) - # return n_row, self.underlying[0].shape[1] - # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index 9379e5f..46bb51b 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -1,7 +1,9 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -def get_variable_indices(variables, state): +def get_variable_indices(state, variables): + + # print(f'{variables=}') if isinstance(variables, ExpressionBaseMixin): state, variables = variables.apply(state) @@ -20,18 +22,20 @@ def get_variable_indices(variables, state): return state, tuple(gen_indices()) - if not isinstance(variables, tuple): - variables = (variables,) + else: + + if not isinstance(variables, tuple): + variables = (variables,) - # assert all(isinstance(variable, type(variables[0])) for variable in variables) + # assert all(isinstance(variable, type(variables[0])) for variable in variables) - def gen_indices(): - for variable in variables: + def gen_indices(): + for variable in variables: - if isinstance(variable, int): - yield variable + if isinstance(variable, int): + yield variable - else: - yield state.offset_dict[variable][0] + else: + yield state.offset_dict[variable][0] - return state, tuple(gen_indices()) + return state, tuple(gen_indices()) diff --git a/polymatrix/polysolver.py b/polymatrix/polysolver.py index da6fd5d..8c055f8 100644 --- a/polymatrix/polysolver.py +++ b/polymatrix/polysolver.py @@ -153,10 +153,11 @@ def solve_poly_system(data, m): else: b_num_inv = scipy.sparse.linalg.inv(data[1]) - n = b_num_inv.shape[0] - p = b_num_inv @ np.ones((n, 1)) + n_var = b_num_inv.shape[0] + p0 = b_num_inv @ np.ones((n_var, 1)) - assert (data[0] + np.ones((n, 1)) < 0.01).all(), f'{data[0]=}, {data[0] + np.ones((n, 1))=}' + assert (data[0] + np.ones((n_var, 1)) < 0.01).all(), f'{data[0]=}, {data[0] + np.ones((n_var, 1))=}' + assert (data[0] + np.ones((n_var, 1)) < 0.01).all(), f'data[0] is not one' def func(acc, _): """ @@ -182,76 +183,73 @@ def solve_poly_system(data, m): n - (2*d-1)*idx, d-1, )) - - def acc_kron(l): - *_, last = itertools.accumulate(l, lambda acc, v: np.kron(acc, v)) - return last def gen_p(): for degree, d_data in data.items(): if 1 < degree and degree-1 <= k: indices_list = list(gen_indices(k-degree+1, degree-1)) - permutations = (perm for indices in indices_list for perm in more_itertools.distinct_permutations(indices)) + permutations = lambda: (perm for indices in indices_list for perm in more_itertools.distinct_permutations(indices)) - if not scipy.sparse.issparse(data): + if not scipy.sparse.issparse(d_data): def acc_kron(perm): *_, last = itertools.accumulate(((p[:,idx:idx+1] for idx in perm)), lambda acc, v: np.kron(acc, v)) return last - yield from (d_data @ acc_kron(perm) for perm in permutations) + yield from (d_data @ acc_kron(perm) for perm in permutations()) else: - csr_data = d_data.tocsr() - - n_eye = sum(1 for idx in indices if idx == 0) - n_col = n**n_eye - n_index = n_col * n**(len(indices) - n_eye) - def gen_array_per_permuatation(): + csr_data = d_data.tocsr() - def gen_coord(perm): - n_row = csr_data.shape[0] + n_row = csr_data.shape[0] + n_col = csr_data.shape[1] - def acc_col_idx_and_value(acc, v): - relindex, relrow, val = acc + def gen_row_values(): + + def acc_kron_operation(acc, v): + relindex, relrow, val = acc - n_relrow = relrow / n - grp_index = int(relindex / n_relrow) - n_relindex = int(relindex - grp_index * n_relrow) - n_val = val * p[grp_index, v:v+1] + n_relrow = relrow / n_var + grp_index = int(relindex / n_relrow) + n_relindex = int(relindex - grp_index * n_relrow) + n_val = val * float(p[grp_index, v:v+1]) - return n_relindex, n_relrow, n_val + return n_relindex, n_relrow, n_val - for row in range(n_row): + for row in range(n_row): - pt = slice(csr_data.indptr[row], csr_data.indptr[row+1]) + pt = slice(csr_data.indptr[row], csr_data.indptr[row+1]) - def gen_val_per_row(): - for inner_idx, array_val in zip(csr_data.indices[pt], csr_data.data[pt]): + def gen_row_multiplication(): + for col_idx, col_val in zip(csr_data.indices[pt], csr_data.data[pt]): + for perm in permutations(): - *_, last = itertools.accumulate( + *_, (_, _, val) = itertools.accumulate( perm, - acc_col_idx_and_value, - initial=(inner_idx, n_index, array_val), + acc_kron_operation, + initial=(col_idx, n_col, col_val), ) - _, _, val = last yield val - - yield sum(gen_val_per_row()) + + yield sum(gen_row_multiplication()) + + # for perm in permutations: - for perm in permutations: + row_values = tuple(gen_row_values()) - array_values = list(gen_coord(perm)) + assert len(row_values) == n_row - yield scipy.sparse.csr_array((array_values, np.zeros(len(array_values)), csr_data.indptr), shape=(csr_data.shape[0], n_col)) + yield scipy.sparse.coo_array( + (row_values, (range(n_row), np.zeros(n_row))), + shape=(n_row, 1), + ) return np.concatenate((p, -b_num_inv @ sum(gen_p())), axis=1) - *_, sol_subs = itertools.accumulate(range(m-1), func, initial=p) + *_, sol_subs = itertools.accumulate(range(m-1), func, initial=p0) - # return np.asarray(sol_subs) return sum(np.asarray(sol_subs).T) @@ -297,7 +295,7 @@ def inner_smart_solve(data, irange=None, idegree=None, a_init=None): return a, err - a, err = list(zip(*itertools.accumulate( + a, err = tuple(zip(*itertools.accumulate( irange, acc_func, initial=(a_init, 0), @@ -319,27 +317,23 @@ def outer_smart_solve(data, a_init=None, n_iter=10, a_max=1.0, irange=None, ideg try: a, err = inner_smart_solve(data, irange=irange, idegree=idegree, a_init=a_subs) - subs_data = substitude_x_add_a(data, a[-1]) - sol = solve_poly_system(subs_data, 6) + # subs_data = substitude_x_add_a(data, a[-1]) + # sol = solve_poly_system(subs_data, 6) - error_index = np.max(np.abs(eval_solution(data, a[-1] + sol))) + max(a[-1]) + # error_index = np.max(np.abs(eval_solution(data, a[-1] + sol))) + max(a[-1]) + # error_index = np.max(np.abs(eval_solution(data, a[-1] + sol))) # error_index = np.max(np.abs(eval_solution(subs_data, sol))) + max(a[-1]) # error_index = np.max(np.abs(eval_solution(subs_data, sol))) except: print(f'nan error, continue') - # print(f'nan error for {a_init=}, continue') - # yield np.nan, a_init, np.nan continue - print(f'{error_index=}') + # print(f'{error_index=}') - yield error_index, a, err - - # if error_index < 1.0: - # break + yield a, err - _, a, err = min(gen_a_err()) + a, err = tuple(zip(*gen_a_err())) return a, err @@ -359,6 +353,7 @@ def eval_solution(data, x=None): # yield d_data @ last if not scipy.sparse.issparse(d_data): + *_, last = itertools.accumulate(degree*(x,), lambda acc, v: np.kron(acc, v)) yield d_data @ last diff --git a/polymatrix/statemonad/__init__.py b/polymatrix/statemonad/__init__.py new file mode 100644 index 0000000..33aa86a --- /dev/null +++ b/polymatrix/statemonad/__init__.py @@ -0,0 +1,16 @@ +from polymatrix.statemonad.init.initstatemonad import init_state_monad +from polymatrix.statemonad.statemonad import StateMonad + + +def zip(monads: tuple[StateMonad]): + + def zip_func(state): + values = tuple() + + for monad in monads: + state, val = monad.apply(state) + values += (val,) + + return state, values + + return init_state_monad(zip_func) diff --git a/polymatrix/statemonad/impl/__init__.py b/polymatrix/statemonad/impl/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/polymatrix/statemonad/impl/__init__.py diff --git a/polymatrix/statemonad/impl/statemonadimpl.py b/polymatrix/statemonad/impl/statemonadimpl.py new file mode 100644 index 0000000..5f8c6b8 --- /dev/null +++ b/polymatrix/statemonad/impl/statemonadimpl.py @@ -0,0 +1,8 @@ +import dataclass_abc +from polymatrix.statemonad.statemonad import StateMonad + +from typing import Callable + +@dataclass_abc.dataclass_abc(frozen=True) +class StateMonadImpl(StateMonad): + apply_func: Callable diff --git a/polymatrix/statemonad/init/__init__.py b/polymatrix/statemonad/init/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/polymatrix/statemonad/init/__init__.py diff --git a/polymatrix/statemonad/init/initstatemonad.py b/polymatrix/statemonad/init/initstatemonad.py new file mode 100644 index 0000000..7d269f3 --- /dev/null +++ b/polymatrix/statemonad/init/initstatemonad.py @@ -0,0 +1,10 @@ +from typing import Callable +from polymatrix.statemonad.impl.statemonadimpl import StateMonadImpl + + +def init_state_monad( + apply_func: Callable, +): + return StateMonadImpl( + apply_func=apply_func, +) diff --git a/polymatrix/statemonad/mixins/__init__.py b/polymatrix/statemonad/mixins/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/polymatrix/statemonad/mixins/__init__.py diff --git a/polymatrix/statemonad/mixins/statemonadmixin.py b/polymatrix/statemonad/mixins/statemonadmixin.py new file mode 100644 index 0000000..39b6576 --- /dev/null +++ b/polymatrix/statemonad/mixins/statemonadmixin.py @@ -0,0 +1,53 @@ +import abc +import dataclasses +from typing import Callable, Tuple, TypeVar, Generic +import typing + +State = TypeVar('State') +U = TypeVar('U') +V = TypeVar('V') + + +class StateMonadMixin( + Generic[State, U], + abc.ABC, +): + @property + @abc.abstractmethod + def apply_func(self) -> typing.Callable[[State], tuple[State, U]]: + ... + + # def init(func: Callable[[State], Tuple[State, U]]) -> 'StateMonadMixin[U, State]': + # class StateMonadImpl(StateMonadMixin): + # def apply(self, state: State) -> Tuple[State, U]: + # return func(state) + + # return StateMonadImpl() + + def map(self, fn: Callable[[U], V]) -> 'StateMonadMixin[State, V]': + + def internal_map(state: State) -> Tuple[State, U]: + n_state, val = self.apply(state) + return n_state, fn(val) + + return dataclasses.replace(self, apply_func=internal_map) + + def flat_map(self, fn: Callable[[U], 'StateMonadMixin']) -> 'StateMonadMixin[State, V]': + + def internal_map(state: State) -> Tuple[State, V]: + n_state, val = self.apply(state) + return fn(val).apply(n_state) + + return dataclasses.replace(self, apply_func=internal_map) + + def zip(self, other: 'StateMonadMixin') -> 'StateMonadMixin': + def internal_map(state: State) -> Tuple[State, V]: + state, val1 = self.apply(state) + state, val2 = other.apply(state) + return state, (val1, val2) + + return dataclasses.replace(self, apply_func=internal_map) + + # @abc.abstractmethod + def apply(self, state: State) -> Tuple[State, U]: + return self.apply_func(state) diff --git a/polymatrix/statemonad/statemonad.py b/polymatrix/statemonad/statemonad.py new file mode 100644 index 0000000..49ab1fa --- /dev/null +++ b/polymatrix/statemonad/statemonad.py @@ -0,0 +1,4 @@ +from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin + +class StateMonad(StateMonadMixin): + pass |