From 5eb4a5016ab95ccce410900c1097761adf2538a0 Mon Sep 17 00:00:00 2001 From: Michael Schneeberger Date: Sat, 5 Mar 2022 15:21:17 +0100 Subject: add inequality constraint and KKT conditions --- polymatrix/impl/optimizationimpl.py | 13 + polymatrix/impl/optimizationstateimpl.py | 11 + polymatrix/impl/polymatriximpl.py | 15 + polymatrix/init/initoptimization.py | 16 + polymatrix/init/initoptimizationstate.py | 14 + polymatrix/init/initpolymatrix.py | 36 +++ polymatrix/mixins/optimizationmixin.py | 435 ++++++++++++++++++++++++++ polymatrix/mixins/optimizationpipeopmixin.py | 12 + polymatrix/mixins/optimizationstatemixin.py | 78 +++++ polymatrix/mixins/polymatrixmixin.py | 54 ++++ polymatrix/optimization.py | 5 + polymatrix/optimizationstate.py | 5 + polymatrix/polymatrix.py | 5 + polymatrix/polystruct.py | 439 +++++++++++---------------- polymatrix/statemonad.py | 33 ++ polymatrix/utils.py | 4 + 16 files changed, 905 insertions(+), 270 deletions(-) create mode 100644 polymatrix/impl/optimizationimpl.py create mode 100644 polymatrix/impl/optimizationstateimpl.py create mode 100644 polymatrix/impl/polymatriximpl.py create mode 100644 polymatrix/init/initoptimization.py create mode 100644 polymatrix/init/initoptimizationstate.py create mode 100644 polymatrix/init/initpolymatrix.py create mode 100644 polymatrix/mixins/optimizationmixin.py create mode 100644 polymatrix/mixins/optimizationpipeopmixin.py create mode 100644 polymatrix/mixins/optimizationstatemixin.py create mode 100644 polymatrix/mixins/polymatrixmixin.py create mode 100644 polymatrix/optimization.py create mode 100644 polymatrix/optimizationstate.py create mode 100644 polymatrix/polymatrix.py create mode 100644 polymatrix/statemonad.py diff --git a/polymatrix/impl/optimizationimpl.py b/polymatrix/impl/optimizationimpl.py new file mode 100644 index 0000000..ac78ea0 --- /dev/null +++ b/polymatrix/impl/optimizationimpl.py @@ -0,0 +1,13 @@ +import dataclass_abc + +from polymatrix.optimization import Optimization +from polymatrix.optimizationstate import OptimizationState + + +@dataclass_abc.dataclass_abc(frozen=True, eq=False) +class OptimizationImpl(Optimization): + # n_var: int + state: OptimizationState + equality_constraints: dict[int, dict[tuple[int, int], float]] + inequality_constraints: dict[int, dict[tuple[int, int], float]] + auxillary_equality: dict[int, dict[tuple[int, int], float]] diff --git a/polymatrix/impl/optimizationstateimpl.py b/polymatrix/impl/optimizationstateimpl.py new file mode 100644 index 0000000..5ce8d11 --- /dev/null +++ b/polymatrix/impl/optimizationstateimpl.py @@ -0,0 +1,11 @@ +import dataclass_abc +from polymatrix.optimizationstate import OptimizationState +from polymatrix.polymatrix import PolyMatrix + + +@dataclass_abc.dataclass_abc(frozen=True, eq=False) +class OptimizationStateImpl(OptimizationState): + n_var: int + n_param: int + offset_dict: dict[tuple[PolyMatrix, int], int] + diff --git a/polymatrix/impl/polymatriximpl.py b/polymatrix/impl/polymatriximpl.py new file mode 100644 index 0000000..cc435e3 --- /dev/null +++ b/polymatrix/impl/polymatriximpl.py @@ -0,0 +1,15 @@ +import typing +import dataclass_abc + +from polymatrix.polymatrix import PolyMatrix + + +@dataclass_abc.dataclass_abc(frozen=True, eq=False) +class PolyMatrixImpl(PolyMatrix): + name: str + degrees: list[int] + subs: dict[int, dict[tuple[int, int], float]] + re_index: typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]] + is_constant: bool + shape: tuple[int, int] + is_negated: bool diff --git a/polymatrix/init/initoptimization.py b/polymatrix/init/initoptimization.py new file mode 100644 index 0000000..c08c71f --- /dev/null +++ b/polymatrix/init/initoptimization.py @@ -0,0 +1,16 @@ +from polymatrix.impl.optimizationimpl import OptimizationImpl +from polymatrix.optimizationstate import OptimizationState + + +def init_optimization( + state: OptimizationState, + equality_constraints: dict[int, dict[tuple[int, int], float]], + inequality_constraints: dict[int, dict[tuple[int, int], float]], + auxillary_equality: dict[int, dict[tuple[int, int], float]] +): + return OptimizationImpl( + state=state, + equality_constraints=equality_constraints, + inequality_constraints=inequality_constraints, + auxillary_equality=auxillary_equality, + ) diff --git a/polymatrix/init/initoptimizationstate.py b/polymatrix/init/initoptimizationstate.py new file mode 100644 index 0000000..e6e323d --- /dev/null +++ b/polymatrix/init/initoptimizationstate.py @@ -0,0 +1,14 @@ +from polymatrix.impl.optimizationstateimpl import OptimizationStateImpl +from polymatrix.mixins.polymatrixmixin import PolyMatrixMixin + + +def init_optimization_state( + n_var: int, + n_param: int, + offset_dict: dict[tuple[PolyMatrixMixin, int], int], +): + return OptimizationStateImpl( + n_var=n_var, + n_param=n_param, + offset_dict=offset_dict, + ) diff --git a/polymatrix/init/initpolymatrix.py b/polymatrix/init/initpolymatrix.py new file mode 100644 index 0000000..edfe195 --- /dev/null +++ b/polymatrix/init/initpolymatrix.py @@ -0,0 +1,36 @@ +import typing + +from polymatrix.impl.polymatriximpl import PolyMatrixImpl + + +def init_poly_matrix( + name: str, + shape: tuple[int, int], + degrees: list[int] = None, + is_constant: bool = None, + subs: dict[int, dict[tuple[int, int], float]] = None, + re_index: typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]] = None, + is_negated: bool = None, +): + if degrees is None: + assert isinstance(subs, dict) + degrees = list(subs.keys()) + + if is_constant is None: + if subs is None: + is_constant = False + else: + is_constant = True + + if is_negated is None: + is_negated = False + + return PolyMatrixImpl( + name=name, + degrees=degrees, + subs=subs, + re_index=re_index, + is_constant=is_constant, + shape=shape, + is_negated=is_negated, + ) diff --git a/polymatrix/mixins/optimizationmixin.py b/polymatrix/mixins/optimizationmixin.py new file mode 100644 index 0000000..d9ea263 --- /dev/null +++ b/polymatrix/mixins/optimizationmixin.py @@ -0,0 +1,435 @@ +import abc +import collections +import dataclasses +import functools +import itertools + +import more_itertools +from polymatrix.mixins.optimizationpipeopmixin import OptimizationPipeOpMixin +from polymatrix.mixins.optimizationstatemixin import OptimizationStateMixin + +from polymatrix.mixins.polymatrixmixin import PolyMatrixMixin +from polymatrix.polystruct import DegreeType +from polymatrix.utils import monomial_to_index + + +class OptimizationMixin(abc.ABC): + @property + @abc.abstractmethod + def state(self) -> OptimizationStateMixin: + ... + + @property + @abc.abstractmethod + def equality_constraints(self) -> dict[int, dict[tuple[int, int], float]]: + ... + + @property + @abc.abstractmethod + def inequality_constraints(self) -> dict[int, dict[tuple[int, int], float]]: + ... + + @property + @abc.abstractmethod + def auxillary_equality(self) -> dict[int, dict[tuple[int, int], float]]: + ... + + @property + def n_var(self) -> int: + return self.state.n_var + + @property + def n_equality_constraints(self) -> int: + if len(self.equality_constraints) == 0: + return 0 + + return max(eq for eq, _ in self.equality_constraints.items()) + 1 + + @property + def n_inequality_constraints(self) -> int: + if len(self.inequality_constraints) == 0: + return 0 + + return max(eq for eq, _ in self.inequality_constraints.items()) + 1 + # return max(key[0] for _, d_data in self.inequality_constraints.items() for key, _ in d_data.items()) + 1 + + @property + def n_auxillary_equality(self) -> int: + if len(self.auxillary_equality) == 0: + return 0 + + return max(eq for eq, _ in self.auxillary_equality.items()) + 1 + # return max(key[0] for _, d_data in self.auxillary_equality.items() for key, _ in d_data.items()) + 1 + + # def pipe(self, *operators: OptimizationPipeOpMixin): + # return functools.reduce(lambda obs, op: op(obs), operators, self) + + def add_equality_constraints(self, expr: tuple[tuple[PolyMatrixMixin, ...], ...]): + all_polymats = tuple(polymat for term in expr for polymat in term) + + # update offsets with unseen polymats + state = self.state.update_offsets(all_polymats) + + eq_constr_buffer = collections.defaultdict(lambda: collections.defaultdict(float)) + + for term in expr: + + for degrees in itertools.product(*(polymat.degrees for polymat in term)): + + total_degree = sum(degrees) + + # precompute substitution/offsets for all polynomial matrices in term + d_subs = [polymat.subs[degree] if polymat.subs is not None and degree in polymat.subs else None for degree, polymat in zip(degrees, term)] + d_offsets = [state.offset_dict.get((polymat, degree), (0, 0)) for polymat, degree in zip(term, degrees)] + + # n_var = 2, degree = 3 + # cominations = [(0, 0, 0), (0, 0, 1), (0, 1, 1), (1, 1, 1)] + for combination in itertools.combinations_with_replacement(range(self.n_var), total_degree): + for x_monom in more_itertools.distinct_permutations(combination): + + def acc_func(acc, v): + last, _ = acc + new = last + v + + return new, x_monom[last:new] + + # monom=(1, 0, 1) -> monom1=(x2, x1), monom2=(x2) + # the monomial either refers to a free parameter, or to a value in the substitution table + p_monoms = list(monom for _, monom in itertools.accumulate(degrees, acc_func, initial=(0, None)))[1:] + + def non_increasing(seq): + return all(y <= x for x, y in zip(seq, seq[1:])) + + # (1,0) -> x2*x1 instead of (0,1)->x1*x2 + if all(non_increasing(monom) for monom in p_monoms): + + col_defaults = [monomial_to_index(self.n_var, monom) for monom in p_monoms] + + n_rows_col = itertools.chain((m.shape[0] for m in term), (term[-1].shape[1],)) + + for curr_rows_col in itertools.product(*[range(e) for e in n_rows_col]): + + def gen_re_indexing(): + for m, degree, (poly_row, poly_col), monom, col_default, subs, offset in zip( + term, degrees, itertools.pairwise(curr_rows_col), p_monoms, col_defaults, d_subs, d_offsets): + + if m.re_index is not None: + re_index = m.re_index(degree, poly_row, poly_col, monom) + else: + re_index = None + + if re_index is None: + col = col_default + new_poly_row = poly_row + new_poly_col = poly_col + factor = 1.0 + + else: + new_poly_row, new_poly_col, new_monom, factor = re_index + col = monomial_to_index(self.n_var, new_monom) + + if subs is not None: + try: + subs_val = subs[(new_poly_row, new_poly_col, col)] + except KeyError: + subs_val = None + else: + subs_val = None + + if subs_val is None: + # the coefficient is selected after the reindexing + + row = new_poly_row + new_poly_col * m.shape[0] + + # linearize parameter matrix + p_index = int(offset[0] + row + col * m.shape[0] * m.shape[1]) + + assert p_index < offset[1], f'{p_index} is bigger than {offset[1]}' + + else: + p_index = None + + yield poly_row, poly_col, p_index, factor, subs_val + + data = tuple(gen_re_indexing()) + + total_factor = functools.reduce(lambda x, y: x*y, (d[3] for d in data)) + + if total_factor == 0: + continue + + value = functools.reduce(lambda x, y: x*y, (d[4] for d in data if d[4] is not None), 1) * total_factor + + if value == 0: + continue + + poly_row = data[0][0] + p_monom = tuple(d[2] for d in data if d[4] is None) + degree = len(p_monom) + + eq_constr_buffer[degree][poly_row, x_monom, p_monom] += value + + # assign equations + rows_perm_set = set((eq_idx, perm) for eq_tuple_degree in eq_constr_buffer.values() for (eq_idx, perm, _) in eq_tuple_degree.keys()) + eq_to_rows = {eq: idx for idx, eq in enumerate(rows_perm_set)} + + eq_constr = collections.defaultdict(list) + + for degree, d_data in eq_constr_buffer.items(): + for (eq_idx, perm, p_monoms), val in d_data.items(): + row = eq_to_rows[(eq_idx, perm)] + eq_constr[row] += ((p_monoms, val),) + + # eq_data = dict(gen_eq_data()) + + return dataclasses.replace(self, state=state, equality_constraints=eq_constr) + + def add_inequality_constraints(self, expr): + # all_polymats = tuple(polymat for term in expr for polymat in term) + polymat = expr[0][0] + all_polymats = (polymat,) + + # update offsets with unseen polymats + state = self.state.update_offsets(all_polymats) + + aux_eq_buffer = collections.defaultdict(list) + ineq_constr_buffer = collections.defaultdict(list) + + ineq_idx = self.n_inequality_constraints + aux_eq_idx = self.n_auxillary_equality + + for degree in polymat.degrees: + # print(f'{degree=}') + + assert degree != 0 + assert degree % 2 == 0 + + vec = tuple(itertools.combinations_with_replacement(range(state.n_var), int(degree/2))) + n_square_mat = len(vec) + + # introduce auxillary parameters + offset_x = state.n_param + # print(f'{offset_x=}') + + n_param_added = int(n_square_mat*(n_square_mat-1)/2) + state = state.update_param(n_param_added) + # print(f'{n_param_added=}') + + offset_a, end_idx = state.offset_dict[(polymat, degree)] + + def get_param_index_of_poly_mat(row, col): + x_monom = vec[row] + vec[col] + reindex = polymat.re_index(degree, 1, 1, x_monom) + + if reindex is None: + n_monom, factor = x_monom, 1.0 + + else: + _, _, n_monom, factor = polymat.re_index(degree, 1, 1, vec[row] + vec[col]) + + # assert all(e < end_idx for e in n_monom) + + new_poly_row = 0 + new_poly_col = 0 + col = monomial_to_index(self.n_var, n_monom) + row = new_poly_row + new_poly_col * polymat.shape[0] + + # linearize parameter matrix + p_index = int(offset_a + row + col * polymat.shape[0] * polymat.shape[1]) + + return (p_index,), factor + + for k in range(n_square_mat): + # f in f-v^T@x-r^2 + monom, factor = get_param_index_of_poly_mat(k, k) + # ineq_constr_buffer[len(monom)][(ineq_idx, monom)] += factor + ineq_constr_buffer[ineq_idx] += ((monom, factor),) + + # print(ineq_constr_buffer) + + for row in range(k): + + # v^T@x in f-v^T@x-r^2 + monom, factor = get_param_index_of_poly_mat(k, row) + # ineq_constr_buffer[len(monom) + 1][(ineq_idx, monom + (offset_x + row,))] -= factor + ineq_constr_buffer[ineq_idx] += ((monom + (offset_x + row,), -factor),) + + # print(f'{k}: {offset_x + row=}') + + for col in range(k): + # P@x in P@x-v + monom, factor = get_param_index_of_poly_mat(row, col) + # aux_eq_buffer[len(monom) + 1][(aux_eq_idx+row, monom + (offset_x + col,))] += factor + aux_eq_buffer[aux_eq_idx] += ((monom + (offset_x + col,), factor),) + + # -v in P@x-v + monom, factor = get_param_index_of_poly_mat(row, k) + # aux_eq_buffer[len(monom)][(aux_eq_idx+row, monom)] -= factor + aux_eq_buffer[aux_eq_idx] += ((monom, -factor),) + + aux_eq_idx += 1 + + ineq_idx += 1 + offset_x += k + + # print(f'{ineq_idx - self.n_inequality_constraints + aux_eq_idx - self.n_auxillary_equality=}') + # print(f'{aux_eq_idx - self.n_auxillary_equality=}') + + return dataclasses.replace(self, inequality_constraints=ineq_constr_buffer, auxillary_equality=aux_eq_buffer, state=state) + + def minimize(self, cost_func=None): + """ + - assume sum of squares cost function on variables x + - introduce nu/lambda for each equality/inequality + - differentiate equality and inequality constraints for each variable + + > equality constraints + > inequality constraints - r1 + > x + nu * equality constraint + lambda * inequality constraint + > lambda - r2 + > r1 * r2 + > P @ x = v + """ + + state = self.state + + n_equality_constraints = self.n_equality_constraints + n_inequality_constraints = self.n_inequality_constraints + + # remove unused parameters + # ------------------------ + + monom_update_reverse = set(m for monom_vals in itertools.chain(self.equality_constraints.values(), self.inequality_constraints.values(), self.auxillary_equality.values()) for monom, _ in monom_vals for m in monom) + # print(tuple(p for p in param_indices if p not in monom_update_reverse)) + monom_update = {monom: idx for idx, monom in enumerate(sorted(monom_update_reverse))} + + assert max(monom_update_reverse) == state.n_param - 1, f'{max(monom_update_reverse)=} is not {state.n_param - 1=}' + + # param_indices = tuple(start + idx for start, end in state.offset_dict.values() for idx in range(end - start)) + # print(tuple(m for m in monom_update_reverse if m not in param_indices)) + # print(param_indices) + + # the variables x are assumed to be the parameters of all registered polynomial matrices + # todo: this would later come from a specific polynomial matrices + param_indices = tuple(monom_update[start + idx] for start, end in state.offset_dict.values() for idx in range(end - start) if start + idx in monom_update_reverse) + + equality_constraints = tuple((key, tuple((tuple(monom_update[m] for m in monom), val) for monom, val in monom_val)) for key, monom_val in self.equality_constraints.items()) + inequality_constraints = tuple((key, tuple((tuple(monom_update[m] for m in monom), val) for monom, val in monom_val)) for key, monom_val in self.inequality_constraints.items()) + auxillary_equality = tuple((key, tuple((tuple(monom_update[m] for m in monom), val) for monom, val in monom_val)) for key, monom_val in self.auxillary_equality.items()) + + # introduce variable nu for each equality/inequality + idx_nu = len(monom_update_reverse) + idx_lambda = idx_nu + n_equality_constraints + idx_r1 = idx_lambda + n_inequality_constraints + idx_r2 = idx_r1 + n_inequality_constraints + + # total_idx = idx_r2 + n_inequality_constraints + # print(f'{total_idx=}') + + eq_buffer = collections.defaultdict(list) + + current_eq_offset = 0 + + # > equality constraints + for eq, monom_val_list in equality_constraints: + eq_buffer[current_eq_offset + eq] += monom_val_list + + current_eq_offset += n_equality_constraints + assert max(eq_buffer.keys()) == current_eq_offset - 1 + + # > inequality constraints - r1 + for eq, monom_val_list in inequality_constraints: + eq_buffer[current_eq_offset + eq] += monom_val_list + (((idx_r1 + eq,), -1),) + + current_eq_offset += n_inequality_constraints + # assert max(eq_buffer.keys()) < current_eq_offset + assert max(eq_buffer.keys()) == current_eq_offset - 1 + + # > x + nu * equality constraint + lambda * inequality constraint + # --------------------------------------------------------------- + + for index, param in enumerate(param_indices): + eq_buffer[current_eq_offset + index] += (((param,), 1),) + + # assert max(eq_buffer.keys()) < current_eq_offset + n_inequality_constraints, f'{max(eq_buffer.keys())=} is bigger than {current_eq_offset + n_inequality_constraints=}' + + # differentiate equality constraints for each variable x + # def gen_derivative(): + for ineq_constr_idx, eq_data in equality_constraints: + for p_monom, val in eq_data: + + p_monom_grp = collections.defaultdict(int) + for m in p_monom: + p_monom_grp[m] += 1 + + for m, counter in p_monom_grp.items(): + if m in param_indices: + def generate_monom(): + for i_m, i_counter in p_monom_grp.items(): + if m is i_m: + sel_counter = i_counter - 1 + else: + sel_counter = i_counter + + for _ in range(sel_counter): + yield i_m + + eq_idx = param_indices.index(m) + der_monomial = tuple(generate_monom()) + (idx_nu + ineq_constr_idx,) + eq_buffer[current_eq_offset + eq_idx] += ((der_monomial, val * counter),) + + for eq_constr_idx, eq_data in inequality_constraints: + for p_monom, val in eq_data: + + p_monom_grp = collections.defaultdict(int) + for m in p_monom: + p_monom_grp[m] += 1 + + for m, counter in p_monom_grp.items(): + if m in param_indices: + def generate_monom(): + for i_m, i_counter in p_monom_grp.items(): + if m is i_m: + sel_counter = i_counter - 1 + else: + sel_counter = i_counter + + for _ in range(sel_counter): + yield i_m + + eq_idx = param_indices.index(m) + der_monomial = tuple(generate_monom()) + (idx_lambda + eq_constr_idx,) + eq_buffer[current_eq_offset + eq_idx] += ((der_monomial, val * counter),) + + current_eq_offset += len(param_indices) + # assert max(eq_buffer.keys()) < current_eq_offset, f'{max(eq_buffer.keys())=} is bigger than {current_eq_offset=}' + assert max(eq_buffer.keys()) == current_eq_offset - 1 + + # > lambda - r2 + for idx in range(n_inequality_constraints): + eq_buffer[current_eq_offset + idx] += (((idx_lambda + idx,), 1), ((idx_r2 + idx,), -1)) + + current_eq_offset += n_inequality_constraints + # assert max(eq_buffer.keys()) < current_eq_offset + assert max(eq_buffer.keys()) == current_eq_offset - 1 + + # > r1 * r2 + for idx in range(n_inequality_constraints): + eq_buffer[current_eq_offset + idx] += (((idx_r1 + idx, idx_r2 + idx), 1),) + + current_eq_offset += n_inequality_constraints + # assert max(eq_buffer.keys()) < current_eq_offset + assert max(eq_buffer.keys()) == current_eq_offset - 1 + + # > P @ x - v + for eq, data in auxillary_equality: + eq_buffer[current_eq_offset + eq] += data + + current_eq_offset += len(monom_update_reverse) - len(param_indices) + # print(f'{current_eq_offset=}') + # print(f'{len(monom_update_reverse) - len(param_indices)=}') + assert max(eq_buffer.keys()) == current_eq_offset - 1, f'{max(eq_buffer.keys())} is not {current_eq_offset - 1}' + + # print(f'{current_eq_offset=}') + + return eq_buffer, monom_update_reverse diff --git a/polymatrix/mixins/optimizationpipeopmixin.py b/polymatrix/mixins/optimizationpipeopmixin.py new file mode 100644 index 0000000..a90c193 --- /dev/null +++ b/polymatrix/mixins/optimizationpipeopmixin.py @@ -0,0 +1,12 @@ +import abc +import typing + + +class OptimizationPipeOpMixin(abc.ABC): + @property + @abc.abstractmethod + def func(self) -> typing.Callable[[typing.Any], typing.Any]: + ... + + def __call__(self, source: typing.Any): + return self.func(source) diff --git a/polymatrix/mixins/optimizationstatemixin.py b/polymatrix/mixins/optimizationstatemixin.py new file mode 100644 index 0000000..ae688bb --- /dev/null +++ b/polymatrix/mixins/optimizationstatemixin.py @@ -0,0 +1,78 @@ +import abc +import itertools +import dataclasses + +from polymatrix.mixins.polymatrixmixin import PolyMatrixMixin +from polymatrix.utils import monomial_to_index + + +class OptimizationStateMixin(abc.ABC): + @property + @abc.abstractmethod + def n_var(self) -> int: + """ + dimension of x + """ + + ... + + @property + @abc.abstractmethod + def n_param(self) -> int: + """ + current number of parameters used in polynomial matrix expressions + """ + + ... + + @property + @abc.abstractmethod + def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], tuple[int, int]]: + ... + + # @property + # @abc.abstractmethod + # def local_index_dict(self) -> dict[tuple[PolyMatrixMixin, int], dict[int, int]]: + # ... + + def get_polymat(self, p_index: int) -> tuple[PolyMatrixMixin, int, int]: + for (polymat, degree), (start, end) in self.offset_dict.items(): + if start <= p_index < end: + return polymat, degree, p_index - start + + raise Exception(f'index {p_index} not found in offset dictionary') + + # def get_num_param_for(self, polymat: PolyMatrixMixin, degree: int) -> int: + # start_idx, end_idx = self.offset_dict[(polymat, degree)] + # return end_idx - start_idx + + def update_offsets(self, polymats: tuple[PolyMatrixMixin]) -> 'OptimizationStateMixin': + registered_polymats = set(polymat for polymat, _ in self.offset_dict.keys()) + # print(registered_polymats) + parametric_polymats = set(p for p in polymats if not p.is_constant and p not in registered_polymats) + # print(parametric_polymats) + + if len(parametric_polymats) == 0: + return self + + else: + + def gen_n_param_per_struct(): + for polymat in parametric_polymats: + for degree in polymat.degrees: + + # number of terms is given by the maximum index + 1 + number_of_terms = int(polymat.shape[0] * polymat.shape[1] * (monomial_to_index(self.n_var, degree*(self.n_var-1,)) + 1)) + + yield (polymat, degree), number_of_terms + + param_key, num_of_terms = tuple(zip(*gen_n_param_per_struct())) + cum_sum = tuple(itertools.accumulate((self.n_param,) + num_of_terms)) + offset_dict = dict(zip(param_key, itertools.pairwise(cum_sum))) + + return dataclasses.replace(self, offset_dict=offset_dict, n_param=cum_sum[-1]) + + def update_param(self, n): + return dataclasses.replace(self, n_param=self.n_param+n) + + diff --git a/polymatrix/mixins/polymatrixmixin.py b/polymatrix/mixins/polymatrixmixin.py new file mode 100644 index 0000000..c569a22 --- /dev/null +++ b/polymatrix/mixins/polymatrixmixin.py @@ -0,0 +1,54 @@ +import abc +import dataclasses +import typing + +from matplotlib.pyplot import streamplot + + +class PolyMatrixMixin(abc.ABC): + @property + @abc.abstractmethod + def name(self) -> str: + ... + + @property + @abc.abstractmethod + def degrees(self) -> list[int]: + ... + + @property + @abc.abstractmethod + def subs(self) -> dict[int, dict[tuple[int, int], float]]: + ... + + @property + @abc.abstractmethod + def re_index(self) -> typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]]: + ... + + @property + @abc.abstractmethod + def is_constant(self) -> int: + ... + + @property + @abc.abstractmethod + def shape(self) -> tuple[int, int]: + ... + + @property + @abc.abstractmethod + def is_negated(self) -> bool: + ... + + def __neg__(self): + return dataclasses.replace(self, is_negated=not self.is_negated) + + def get_parameter_name(self, degree: int, index: int) -> str: + n_square = self.shape[0] * self.shape[1] + param_index = int(index / n_square) + n_index = index - param_index * n_square + col_index = int(n_index / self.shape[0]) + row_index = n_index - col_index * self.shape[0] + + return f'{self.name}_({degree},{row_index},{col_index},{param_index})' diff --git a/polymatrix/optimization.py b/polymatrix/optimization.py new file mode 100644 index 0000000..c5034a1 --- /dev/null +++ b/polymatrix/optimization.py @@ -0,0 +1,5 @@ +from polymatrix.mixins.optimizationmixin import OptimizationMixin + + +class Optimization(OptimizationMixin): + pass diff --git a/polymatrix/optimizationstate.py b/polymatrix/optimizationstate.py new file mode 100644 index 0000000..775e976 --- /dev/null +++ b/polymatrix/optimizationstate.py @@ -0,0 +1,5 @@ +from polymatrix.mixins.optimizationstatemixin import OptimizationStateMixin + + +class OptimizationState(OptimizationStateMixin): + pass diff --git a/polymatrix/polymatrix.py b/polymatrix/polymatrix.py new file mode 100644 index 0000000..ff43af5 --- /dev/null +++ b/polymatrix/polymatrix.py @@ -0,0 +1,5 @@ +from polymatrix.mixins.polymatrixmixin import PolyMatrixMixin + + +class PolyMatrix(PolyMatrixMixin): + pass diff --git a/polymatrix/polystruct.py b/polymatrix/polystruct.py index ad6cb74..808b7ab 100644 --- a/polymatrix/polystruct.py +++ b/polymatrix/polystruct.py @@ -45,146 +45,118 @@ class PolyMatrixMixin(abc.ABC): def shape(self) -> tuple[int, int]: ... -class EqualityConstraintMixin(abc.ABC): + +class State(abc.ABC): @property @abc.abstractmethod - def terms(self) -> dict[int, dict[tuple[int, tuple, int], float]]: + def n_param(self) -> int: ... @property @abc.abstractmethod - def n_param(self) -> int: + def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], int]: ... @property @abc.abstractmethod - def variable_to_index(self) -> typing.Callable[[int, tuple[int, ...]], int]: + def local_index_dict(self) -> dict[tuple[PolyMatrixMixin, int], dict[int, int]]: ... - @functools.cached_property - def eq_to_row_index(self): - rows_to_eq = list(set((eq_idx, perm) for eq_tuple_degree in self.terms.values() for (eq_idx, perm, var) in eq_tuple_degree.keys())) - eq_to_rows = {eq: idx for idx, eq in enumerate(rows_to_eq)} - return eq_to_rows - @property - def n_eq(self): - return len(self.eq_to_row_index) - - def get_constraint_matrix(self): - def gen_sparse_matrices(): - for degree, degree_tuples in self.terms.items(): - def gen_row_col_data(): - for (idx_eq, perm, variables), value in degree_tuples.items(): - row = self.eq_to_row_index[(idx_eq, perm)] - col = self.variable_to_index(self.n_param, variables) - yield row, col, value - - row, col, data = list(zip(*gen_row_col_data())) - - data = np.array(data, dtype=np.float) - - if degree <= 1: - yield degree, scipy.sparse.coo_array((data, (row, col)), shape=(self.n_eq, self.n_param**degree)).toarray() - else: - yield degree, scipy.sparse.coo_array((data, (row, col)), shape=(self.n_eq, self.n_param**degree)) - - return dict(gen_sparse_matrices()) - - def get_constraint_func(self): - def func(x): - mat = np.zeros((self.n_eq,)) - - for degree, degree_tuples in self.terms.items(): - if 0 == degree: - for (idx_eq, perm, variables), value in degree_tuples.items(): - row_idx = self.eq_to_row_index[(idx_eq, perm)] - mat[row_idx] += value + @abc.abstractmethod + def eq_constraints(self) -> dict[int, dict[tuple[int, int], float]]: + ... - elif 0 < degree: - def gen_vector(): - for indices in itertools.combinations_with_replacement(range(self.n_param), degree): - yield np.prod(list(x[idx] for idx in indices)) - vector = list(gen_vector()) + @property + @abc.abstractmethod + def ineq_constraints(self) -> dict[int, dict[tuple[int, int], float]]: + ... - for (idx_eq, perm, variables), value in degree_tuples.items(): - row_idx = self.eq_to_row_index[(idx_eq, perm)] - vector_val = vector[self.variable_to_index(self.n_param, variables)] - mat[row_idx] += value * vector_val - return mat - return func - def get_constraint_jacobian(self): - def func(x): - jac_mat = np.zeros((self.n_eq, self.n_param)) +class PolyExpressionMixin(abc.ABC): + @property + @abc.abstractmethod + def data(self) -> dict[int, dict[tuple[int, int], float]]: + ... - for degree, degree_tuples in self.terms.items(): - if 1 == degree: - for (idx_eq, perm, variables), value in degree_tuples.items(): - row_idx = self.eq_to_row_index[(idx_eq, perm)] - jac_mat[row_idx, variables[0]] += value + @property + @abc.abstractmethod + def n_var(self) -> int: + ... - # for var in variables: - # col_idx = variable_to_index(self.n_param, (var,)) - # jac_mat[row_idx, col_idx] += value + @property + @abc.abstractmethod + def n_eq(self): + ... - elif 1 < degree: - def gen_vector(): - for indices in itertools.combinations_with_replacement(range(self.n_param), degree-1): - yield np.prod(list(x[idx] for idx in indices)) - vector = list(gen_vector()) + # @property + # @abc.abstractmethod + # def n_param(self) -> int: + # ... - for (idx_eq, perm, variables), value in degree_tuples.items(): - row_idx = self.eq_to_row_index[(idx_eq, perm)] + # @property + # @abc.abstractmethod + # def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], int]: + # ... - for var_idx, var in enumerate(variables): - other_variables = variables[:var_idx] + variables[var_idx+1:] - vector_val = vector[self.variable_to_index(self.n_param, other_variables)] - # col_idx = variable_to_index(self.n_param, (var,)) - jac_mat[row_idx, var] += value*vector_val +class OptimizationMixin(abc.ABC): + @property + @abc.abstractmethod + def cost_function(self) -> dict[int, dict[int, float]]: + ... - return jac_mat - return func + @property + @abc.abstractmethod + def eq_constraints(self) -> dict[int, dict[tuple[int, int], float]]: + ... - def get_constraint_hessian(self): - def func(x, v): - hess_mat = np.zeros((self.n_param, self.n_param)) + @property + @abc.abstractmethod + def ineq_constraints(self) -> dict[int, dict[tuple[int, int], float]]: + ... - for degree, degree_tuples in self.terms.items(): - if 2 == degree: - for (idx_eq, perm, variables), value in degree_tuples.items(): - eq_idx = self.eq_to_row_index[(idx_eq, perm)] + @property + @abc.abstractmethod + def n_var(self) -> int: + ... - for var_idx_x, var_x in enumerate(variables): - other_variables = variables[:var_idx_x] + variables[var_idx_x+1:] + # @property + # @abc.abstractmethod + # def n_param(self) -> int: + # ... - for var_idx_y, var_y in enumerate(other_variables): - hess_mat[var_x, var_y] = v[eq_idx]*value + @property + @abc.abstractmethod + def n_eq_constraints(self): + ... - elif 2 < degree: - def gen_vector(): - for indices in itertools.combinations_with_replacement(range(self.n_param), degree-2): - yield np.prod(list(x[idx] for idx in indices)) - vector = list(gen_vector()) + @property + @abc.abstractmethod + def n_ineq_constraints(self): + ... - for (idx_eq, perm, variables), value in degree_tuples.items(): - eq_idx = self.eq_to_row_index[(idx_eq, perm)] + # @property + # @abc.abstractmethod + # def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], int]: + # ... - for var_idx_x, var_x in enumerate(variables): - other_variables = variables[:var_idx_x] + variables[var_idx_x+1:] + def create(self) -> PolyExpressionMixin: + """ + - adds nu and lambda for each equality and inequality + - lambda - r_lambda = 0 for each inequality + - lambda r_ineq = 0 for each inequality + - add vanishing gradient + """ - for var_idx_y, var_y in enumerate(other_variables): - other_variables_2 = variables[:var_idx_y] + variables[var_idx_y+1:] - vector_val = vector[self.variable_to_index(self.n_param, other_variables_2)] - hess_mat[var_x, var_y] = v[eq_idx]*value*vector_val + pass - return hess_mat - return func + def add_positive_definiteness_condition(self, key): + pass -class PolyEquationMixin(abc.ABC): +class PolyMatrixEquationMixin(abc.ABC): @property @abc.abstractmethod def terms(self) -> list[tuple[PolyMatrixMixin, PolyMatrixMixin]]: @@ -207,14 +179,17 @@ class PolyEquationMixin(abc.ABC): @property @abc.abstractmethod - def variable_to_index(self) -> typing.Callable[[int, tuple[int, ...]], int]: + def monom_to_index(self) -> typing.Callable[[int, tuple[int, ...]], int]: ... - @functools.cached_property - def _param_list(self) -> list[tuple[PolyMatrixMixin, int], int]: - """ - used to determine the offset of the coefficients of each polynomial matrix - """ + def create( + self, + subs: dict[PolyMatrixMixin, dict[DegreeType, dict[int, int, float]]] = None, + ) -> PolyExpressionMixin: + if subs is None: + added_subs = {} + else: + added_subs = subs # create parameter offset all_structs = set(indexed_poly_mat for term in self.terms for indexed_poly_mat in term) @@ -227,70 +202,16 @@ class PolyEquationMixin(abc.ABC): continue for degree in struct.degrees: - number_of_terms = int(struct.shape[0] * struct.shape[1] * (self.variable_to_index(self.n_var, degree*(self.n_var-1,)) + 1)) - # number_of_terms = int(struct.shape[0] * struct.shape[1] * binom(self.n_var+degree-1, degree)) - - # print(f'{struct=}, {number_of_terms=}') - + number_of_terms = int(struct.shape[0] * struct.shape[1] * (self.monom_to_index(self.n_var, degree*(self.n_var-1,)) + 1)) yield (struct, degree), number_of_terms - param_list = list(gen_n_param_per_struct()) - - return param_list - - @functools.cached_property - def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], int]: - """ - determine the offset of the coefficients of each polynomial matrix ordered by degree - - The polynomial equation - - A * B = 0 - - is represented by a vector of coefficients `coeff`. Each coefficients is associated to polynomial matrix. - - For example - - offset_dict[(A,0)] = 12 - - means that the first coefficient a011 (meaning 011=degree+row+col) associated to A and degree 0 is located at index 12 of `coeff`. - """ - - param_key_value = list(zip(*self._param_list)) - - if 0 < len(param_key_value): - param_key, param_value = param_key_value - cum_sum = list(itertools.accumulate(param_value)) - offset_dict = dict(zip(param_key, [0] + cum_sum[:-1])) - else: - offset_dict = {} - - return offset_dict - - @functools.cached_property - def n_param(self) -> dict[tuple[PolyMatrixMixin, int], int]: - """ - number of coefficients of polynomial matrix equation, e.g. `len(coeff)` - """ - - if 0 < len(self._param_list): - *_, n_param = itertools.accumulate(e[1] for e in self._param_list) - else: - n_param = 0 + # param_list = list(gen_n_param_per_struct()) + param_key, param_value = list(zip(*gen_n_param_per_struct())) + cum_sum = list(itertools.accumulate(param_value)) + offset_dict = dict(zip(param_key, [0] + cum_sum[:-1])) - return n_param - - def create( - self, - subs: dict[PolyMatrixMixin, dict[DegreeType, dict[int, int, float]]] = None, - ) -> EqualityConstraintMixin: - if subs is None: - added_subs = {} - else: - added_subs = subs - - # create parameter offset - all_structs = set(indexed_poly_mat for term in self.terms for indexed_poly_mat in term) + # # create parameter offset + # all_structs = set(indexed_poly_mat for term in self.terms for indexed_poly_mat in term) def gen_substitutions(): for struct in all_structs: @@ -319,7 +240,7 @@ class PolyEquationMixin(abc.ABC): subs_dict = dict(gen_substitutions()) - terms = collections.defaultdict(lambda: collections.defaultdict(float)) + eq_constr_buffer = collections.defaultdict(lambda: collections.defaultdict(float)) def gen_re_indexing(term, degrees, curr_rows_col, monoms, col_defaults, d_subs, d_offsets, perm): for m, degree, (poly_row, poly_col), monom, col_default, subs, offset in zip( @@ -338,7 +259,7 @@ class PolyEquationMixin(abc.ABC): else: new_poly_row, new_poly_col, new_monom, factor = re_index - col = self.variable_to_index(self.n_var, new_monom) + col = self.monom_to_index(self.n_var, new_monom) if subs is not None: try: @@ -355,16 +276,10 @@ class PolyEquationMixin(abc.ABC): # linearize parameter matrix param_idx = int(offset + row + col * m.shape[0] * m.shape[1]) + else: param_idx = None - # if poly_row == 0 and perm == (0,) and monom == (0,): - # print(f'{new_poly_row=}') - # print(f'{new_poly_col=}') - # print(f'{col=}') - # print(f'{monom=}') - # print(f'{new_monom=}') - yield poly_row, poly_col, param_idx, factor, subs_val for term in self.terms: @@ -375,21 +290,21 @@ class PolyEquationMixin(abc.ABC): for degrees in itertools.product(*(m.degrees for m in term)): total_degree = sum(degrees) d_subs = [subs[degree] if subs is not None and degree in subs else None for degree, subs in zip(degrees, term_subs)] - d_offsets = [self.offset_dict.get((m, degree), 0) for m, degree in zip(term, degrees)] + d_offsets = [offset_dict.get((m, degree), 0) for m, degree in zip(term, degrees)] for combination in itertools.combinations_with_replacement(range(self.n_var), total_degree): # n_var = 2, degree = 3 # cominations = [(0, 0, 0), (0, 0, 1), (0, 1, 1), (1, 1, 1)] - for perm in more_itertools.distinct_permutations(combination): + for monom in more_itertools.distinct_permutations(combination): def acc_func(acc, v): last, _ = acc new = last + v - return new, perm[last:new] + return new, monom[last:new] - # perm=(1, 0, 1) -> monom1=(x2, x1), monom2=(x2) + # monom=(1, 0, 1) -> monom1=(x2, x1), monom2=(x2) monoms = list(monom for _, monom in itertools.accumulate(degrees, acc_func, initial=(0, None)))[1:] def non_increasing(seq): @@ -398,13 +313,13 @@ class PolyEquationMixin(abc.ABC): # (1,0) -> x2*x1 instead of (0,1)->x1*x2 if all(non_increasing(monom) for monom in monoms): - col_defaults = [self.variable_to_index(self.n_var, monom) for monom in monoms] + col_defaults = [self.monom_to_index(self.n_var, monom) for monom in monoms] n_rows_col = itertools.chain((m.shape[0] for m in term), (term[-1].shape[1],)) for curr_rows_col in itertools.product(*[range(e) for e in n_rows_col]): - data = tuple(gen_re_indexing(term, degrees, curr_rows_col, monoms, col_defaults, d_subs, d_offsets, perm)) + data = tuple(gen_re_indexing(term, degrees, curr_rows_col, monoms, col_defaults, d_subs, d_offsets, monom)) total_factor = functools.reduce(lambda x, y: x*y, (d[3] for d in data)) @@ -419,61 +334,60 @@ class PolyEquationMixin(abc.ABC): poly_row = data[0][0] param_idx = tuple(d[2] for d in data if d[4] is None) degree = len(param_idx) - - # if poly_row == 0 and perm == (0,): - # print(f'{param_idx=}') - # print(f'{data=}') - # print(f'{monoms=}') - # print(f'{total_factor=}') - - terms[degree][poly_row, perm, param_idx] += value - - return EqualityConstraintImpl( - terms=terms, - n_param=self.n_param, - variable_to_index=self.variable_to_index, - ) - - def matrix_to_poly(self, struct, x, param, tol=None): - assert len(x) == self.n_var, f'variable {x} needs to be of length {self.n_var}' - n_var_2 = self.n_var**2 - - if struct.is_vector: - n_col = 1 - - else: - n_col = self.n_var - - sym_expr = [[0 for _ in range(n_col)] for _ in range(self.n_var)] - - for degree in struct.degrees: - offset = self.offset_dict[(struct, degree)] - # number_of_terms = int(binom(self.n_var+degree-1, degree)) - - def write_to_expr(row, col, val, term=1): - if tol is None or val <= -tol or tol <= val: - sym_expr[row][col] += val*term - - if 0 == degree: - for row in range(self.n_var): - for col in range(n_col): - write_to_expr(row, col, param[offset + row + col * self.n_var]) - - else: - def gen_vector(): - for comb in itertools.combinations_with_replacement(range(self.n_var), degree): - *_, last = itertools.accumulate(comb, lambda acc, idx: acc*x[idx], initial=1) - yield last - vector = list(gen_vector()) - - for row in range(self.n_var): - for col in range(n_col): - for idx, term in enumerate(vector): - # print(f'{offset + (row + col * self.n_var) * number_of_terms + idx=}, {param[offset + (row + col * self.n_var) * number_of_terms + idx]=}') - write_to_expr(row, col, param[offset + row + col * self.n_var + idx * n_var_2], term) - - return sym_expr + eq_constr_buffer[degree][poly_row, monom, param_idx] += value + + # assign equations + rows_perm_set = set((eq_idx, perm) for eq_tuple_degree in eq_constr_buffer.values() for (eq_idx, perm, _) in eq_tuple_degree.keys()) + eq_to_rows = {eq: idx for idx, eq in enumerate(rows_perm_set)} + + # # calculate offset + # monom_update_reverse = sorted(set(m for eq_tuple_degree in eq_constr_buffer.values() for (_, _, monoms) in eq_tuple_degree.keys() for m in monoms)) + # monom_update = {monom: idx for idx, monom in enumerate(monom_update_reverse)} + + # def gen_n_cum_sum(): + # groups = itertools.groupby(monom_update_reverse, lambda v: next((i for i, g in enumerate(cum_sum) if v < g))) + # for _, group in groups: + # yield sum(1 for _ in group) + + # n_cum_sum = list(itertools.accumulate(gen_n_cum_sum())) + # n_offset_dict = dict(zip(param_key, [0] + n_cum_sum[:-1])) + # n_param = n_cum_sum[-1] + + # def gen_eq_data(): + # for degree, d_data in eq_constr_buffer.items(): + # def gen_eq_degree_data(): + # for (eq_idx, perm, monoms), val in d_data.items(): + # row = eq_to_rows[(eq_idx, perm)] + # monoms_updated = tuple(monom_update[m] for m in monoms) + # # print(f'{n_param=}') + # col = self.monom_to_index(n_param, monoms_updated) + # yield (row, col), val + + # yield degree, dict(gen_eq_degree_data()) + + # eq_data = dict(gen_eq_data()) + + def gen_eq_data(): + for degree, d_data in eq_constr_buffer.items(): + def gen_eq_degree_data(): + for (eq_idx, perm, monoms), val in d_data.items(): + row = eq_to_rows[(eq_idx, perm)] + # monoms_updated = tuple(monom_update[m] for m in monoms) + # col = self.monom_to_index(n_param, monoms_updated) + yield (row, monoms), val + + yield degree, dict(gen_eq_degree_data()) + + eq_data = dict(gen_eq_data()) + + return PolyEquationImpl( + data=eq_data, + n_param=cum_sum[-1], + n_eq=len(eq_to_rows), + n_var=self.n_var, + offset_dict=offset_dict, + ) ######################################## # Classes @@ -483,11 +397,11 @@ class PolyMatrix(PolyMatrixMixin): pass -class EqualityConstraint(EqualityConstraintMixin): +class PolyEquation(PolyExpressionMixin): pass -class PolyEquation(PolyEquationMixin): +class PolyMatrixEquation(PolyMatrixEquationMixin): pass @@ -505,39 +419,24 @@ class PolyMatrixImpl(PolyMatrix): @dataclass_abc.dataclass_abc(frozen=True) -class EqualityConstraintImpl(EqualityConstraintMixin): - terms: dict[int, dict[tuple[int, tuple, int], float]] - variable_to_index: typing.Callable[[int, tuple[int, ...]], int] +class PolyEquationImpl(PolyEquation): + data: dict[int, dict[tuple[int, tuple, int], float]] + offset_dict: dict[tuple[PolyMatrixMixin, int], int] n_param: int + n_eq: int + n_var: int @dataclass_abc.dataclass_abc(frozen=True) -class EquationImpl(PolyEquation): +class PolyMatrixEquationImpl(PolyMatrixEquation): terms: list[tuple[PolyMatrix, PolyMatrix]] - variable_to_index: typing.Callable[[int, tuple[int, ...]], int] + monom_to_index: typing.Callable[[int, tuple[int, ...]], int] n_var: int ######################################## # init functions ######################################## -# def init_poly_vector( -# n_row: int, -# degrees: list[int] = None, -# subs: dict[int, dict[tuple[int, int], float]] = None, -# re_index: typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]] = None, -# is_constant: bool = None, -# ): -# return init_poly_matrix( -# degrees=degrees, -# subs=subs, -# re_index=re_index, -# shape=(n_row, , -# n_col=1, -# is_constant=is_constant, -# ) - - def init_poly_matrix( shape: tuple[int, int], degrees: list[int] = None, @@ -567,15 +466,15 @@ def init_poly_matrix( def init_equation( n_var: int, terms: list[tuple[PolyMatrix, PolyMatrix]], - variable_to_index: typing.Callable[[int, tuple[int, ...]], int] = None, + monom_to_index: typing.Callable[[int, tuple[int, ...]], int] = None, ): # assert all(not left.is_vector and right.is_vector for left, right in terms) - if variable_to_index is None: - variable_to_index = polymatrix.utils.variable_to_index + if monom_to_index is None: + monom_to_index = polymatrix.utils.variable_to_index - return EquationImpl( + return PolyMatrixEquationImpl( n_var=n_var, terms=terms, - variable_to_index=variable_to_index, + monom_to_index=monom_to_index, ) diff --git a/polymatrix/statemonad.py b/polymatrix/statemonad.py new file mode 100644 index 0000000..9be3b61 --- /dev/null +++ b/polymatrix/statemonad.py @@ -0,0 +1,33 @@ +from typing import Callable, Tuple, Any, TypeVar, Generic + +State = TypeVar('State') +U = TypeVar('U') +V = TypeVar('V') + + +class StateMonad(Generic[U, State]): + def __init__(self, fn: Callable[[State], Tuple[U, State]]) -> None: + self._fn = fn + + @classmethod + def unit(cls, value: Any) -> 'StateMonad[U, State]': + return cls(lambda state: (value, state)) + + def map(self, fn: Callable[[U], V]) -> 'StateMonad[V, State]': + def internal_map(state: State) -> Tuple[U, State]: + val, n_state = self._fn(state) + return fn(val), n_state + + return StateMonad(internal_map) + + def flat_map(self, fn: Callable[[U], 'StateMonad']) -> 'StateMonad[V, State]': + + def internal_map(state: State) -> Tuple[V, State]: + val, n_state = self._fn(state) + + return fn(val).run(n_state) + + return StateMonad(internal_map) + + def run(self, state: Any) -> Tuple[Any, Any]: + return self._fn(state) diff --git a/polymatrix/utils.py b/polymatrix/utils.py index e53db35..2595bf6 100644 --- a/polymatrix/utils.py +++ b/polymatrix/utils.py @@ -1,6 +1,10 @@ import scipy.special +def monomial_to_index(n_var, monomial): + return sum(idx*(n_var**level) for level, idx in enumerate(monomial)) + + def variable_to_index(n_var, combination): """ example: -- cgit v1.2.1