From 6a9ca2e0c327281c1359b6152f6afc9165e6bf9b Mon Sep 17 00:00:00 2001 From: Michael Schneeberger Date: Tue, 27 Feb 2024 07:39:52 +0100 Subject: improve typing with PolynomialData and MonomialData --- polymatrix/__init__.py | 519 +-------------------- polymatrix/denserepr/from_.py | 8 +- polymatrix/denserepr/utils/monomialtoindex.py | 11 + .../expression/mixins/combinationsexprmixin.py | 3 +- .../expression/mixins/derivativeexprmixin.py | 14 +- .../expression/mixins/divergenceexprmixin.py | 13 +- polymatrix/expression/mixins/elemmultexprmixin.py | 2 +- .../expression/mixins/linearmonomialsexprmixin.py | 2 +- .../expression/mixins/matrixmultexprmixin.py | 2 +- polymatrix/expression/mixins/productexprmixin.py | 4 +- .../expression/mixins/quadraticinexprmixin.py | 2 +- .../mixins/quadraticmonomialsexprmixin.py | 2 +- .../expression/mixins/substituteexprmixin.py | 2 +- .../mixins/subtractmonomialsexprmixin.py | 4 +- .../expression/utils/getderivativemonomials.py | 13 +- polymatrix/expression/utils/getmonomialindices.py | 5 +- polymatrix/expression/utils/gettupleremainder.py | 9 - polymatrix/expression/utils/getvariableindices.py | 50 +- .../expression/utils/mergemonomialindices.py | 22 - polymatrix/expression/utils/monomialtoindex.py | 7 - polymatrix/expression/utils/multiplymonomials.py | 15 - polymatrix/expression/utils/multiplypolynomial.py | 28 -- polymatrix/expression/utils/sortmonomialindices.py | 5 - polymatrix/expression/utils/sortmonomials.py | 5 - .../expression/utils/splitmonomialindices.py | 24 - .../expression/utils/subtractmonomialindices.py | 23 - polymatrix/polymatrix/impl.py | 7 +- polymatrix/polymatrix/init.py | 5 +- polymatrix/polymatrix/mixins.py | 14 +- polymatrix/polymatrix/typing.py | 8 +- .../polymatrix/utils/mergemonomialindices.py | 38 ++ polymatrix/polymatrix/utils/multiplypolynomial.py | 32 ++ polymatrix/polymatrix/utils/sortmonomialindices.py | 10 + polymatrix/polymatrix/utils/sortmonomials.py | 9 + .../polymatrix/utils/splitmonomialindices.py | 29 ++ .../polymatrix/utils/subtractmonomialindices.py | 28 ++ 36 files changed, 225 insertions(+), 749 deletions(-) create mode 100644 polymatrix/denserepr/utils/monomialtoindex.py delete mode 100644 polymatrix/expression/utils/gettupleremainder.py delete mode 100644 polymatrix/expression/utils/mergemonomialindices.py delete mode 100644 polymatrix/expression/utils/monomialtoindex.py delete mode 100644 polymatrix/expression/utils/multiplymonomials.py delete mode 100644 polymatrix/expression/utils/multiplypolynomial.py delete mode 100644 polymatrix/expression/utils/sortmonomialindices.py delete mode 100644 polymatrix/expression/utils/sortmonomials.py delete mode 100644 polymatrix/expression/utils/splitmonomialindices.py delete mode 100644 polymatrix/expression/utils/subtractmonomialindices.py create mode 100644 polymatrix/polymatrix/utils/mergemonomialindices.py create mode 100644 polymatrix/polymatrix/utils/multiplypolynomial.py create mode 100644 polymatrix/polymatrix/utils/sortmonomialindices.py create mode 100644 polymatrix/polymatrix/utils/sortmonomials.py create mode 100644 polymatrix/polymatrix/utils/splitmonomialindices.py create mode 100644 polymatrix/polymatrix/utils/subtractmonomialindices.py diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 6592238..45968aa 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -2,541 +2,24 @@ from polymatrix.expressionstate.abc import ExpressionState as internal_Expressio from polymatrix.expressionstate.init import init_expression_state as internal_init_expression_state from polymatrix.expression.expression import Expression as internal_Expression from polymatrix.expression.from_ import from_ as internal_from -# from polymatrix.expression.from_ import from_sympy as internal_from_sympy from polymatrix.expression import v_stack as internal_v_stack from polymatrix.expression import h_stack as internal_h_stack from polymatrix.expression import product as internal_product -from polymatrix.denserepr.from_ import from_polymatrix -# from polymatrix.expression.to import shape as internal_shape from polymatrix.expression.to import to_constant as internal_to_constant -# from polymatrix.expression.to import to_degrees as internal_to_degrees from polymatrix.expression.to import to_sympy as internal_to_sympy +from polymatrix.denserepr.from_ import from_polymatrix Expression = internal_Expression ExpressionState = internal_ExpressionState init_expression_state = internal_init_expression_state from_ = internal_from -# from_sympy = internal_from_sympy v_stack = internal_v_stack h_stack = internal_h_stack product = internal_product -# to_shape = internal_shape to_constant_repr = internal_to_constant to_constant = internal_to_constant -# to_degrees = internal_to_degrees to_sympy_repr = internal_to_sympy to_sympy = internal_to_sympy to_matrix_repr = from_polymatrix to_dense = from_polymatrix - -# def from_sympy( -# data: tuple[tuple[float]], -# ): -# return init_expression( -# init_from_sympy_expr(data) -# ) - -# def from_state_monad( -# data: StateMonad, -# ): -# return init_expression( -# data.flat_map(lambda inner_data: init_from_sympy_expr(inner_data)), -# ) - -# def from_polymatrix( -# polymatrix: PolyMatrix, -# ): -# return init_expression( -# init_from_terms_expr(polymatrix) -# ) - -# def from_( -# data: str | tuple[tuple[float]], -# ): -# if isinstance(data, str): -# return from_(1).parametrize(data) - -# return from_sympy(data) - -# def v_stack( -# expressions: tuple[Expression], -# ): -# return init_expression( -# init_v_stack_expr(expressions) -# ) - -# def h_stack( -# expressions: tuple[Expression], -# ): -# return init_expression( -# init_v_stack_expr(tuple(expr.T for expr in expressions)) -# ).T - -# def block_diag( -# expressions: tuple[Expression], -# ): -# return init_expression( -# init_block_diag_expr(expressions) -# ) - -# def eye( -# variable: tuple[Expression], -# ): -# return init_expression( -# init_eye_expr(variable=variable) -# ) - -# def kkt_equality( -# equality: Expression, -# variables: Expression, -# name: str, -# ): -# equality_der = equality.diff( -# variables, -# introduce_derivatives=True, -# ) - -# nu = equality_der[0, :].T.parametrize(name) - -# return nu, equality_der @ nu - -# def kkt_inequality( -# inequality: Expression, -# variables: Expression, -# name: str, -# ): - -# inequality_der = inequality.diff( -# variables, -# introduce_derivatives=True, -# ) - -# nu = inequality_der[0, :].T.parametrize(f'lambda_{name}') -# r_nu = inequality_der[0, :].T.parametrize(f'r_lambda_{name}') -# r_ineq = inequality_der[0, :].T.parametrize(f'r_ineq_{name}') - -# prim_expr = inequality - r_ineq * r_ineq - -# sec_expr = nu - r_nu * r_nu - -# compl_expr = nu * inequality - -# cost_expr = inequality_der @ nu - -# return h_stack((nu, r_nu, r_ineq)), (cost_expr, prim_expr, sec_expr, compl_expr) - - -# def rows( -# expr: Expression, -# ) -> StateMonadMixin[ExpressionState, np.ndarray]: - -# def func(state: ExpressionState): -# state, underlying = expr.apply(state) - -# def gen_row_terms(): -# for row in range(underlying.shape[0]): - -# terms = {} - -# for col in range(underlying.shape[1]): -# polynomial = underlying.get_poly(row, col) -# if polynomial == None: -# continue - -# terms[0, col] = polynomial - -# yield init_expression(underlying=init_from_terms_expr( -# terms=terms, -# shape=(1, underlying.shape[1]) -# )) - -# row_terms = tuple(gen_row_terms()) - -# return state, row_terms - -# return init_state_monad(func) - - -# @dataclasses.dataclass -# class MatrixBuffer: -# data: dict[int, np.ndarray] -# n_row: int -# n_param: int - -# def get_max_degree(self): -# return max(degree for degree in self.data.keys()) - -# def add_buffer(self, index: int): -# if index <= 1: -# buffer = np.zeros((self.n_row, self.n_param**index), dtype=np.double) - -# else: -# buffer = scipy.sparse.dok_array((self.n_row, self.n_param**index), dtype=np.double) - -# self.data[index] = buffer - -# def add(self, row: int, col: int, index: int, value: float): -# if index not in self.data: -# self.add_buffer(index) - -# self.data[index][row, col] = value - -# def __getitem__(self, key): -# if key not in self.data: -# self.add_buffer(key) - -# return self.data[key] - - -# @dataclasses.dataclass -# class MatrixRepresentations: -# data: tuple[MatrixBuffer, ...] -# aux_data: typing.Optional[MatrixBuffer] -# variable_mapping: tuple[int, ...] -# state: ExpressionState - -# def merge_matrix_equations(self): -# def gen_matrices(index: int): -# for equations in self.data: -# if index < len(equations): -# yield equations[index] - -# if index < len(self.aux_data): -# yield self.aux_data[index] - -# indices = set(key for equations in self.data + (self.aux_data,) for key in equations.keys()) - -# def gen_matrices(): -# for index in indices: -# if index <= 1: -# yield index, np.vstack(tuple(gen_matrices(index))) -# else: -# yield index, scipy.sparse.vstack(tuple(gen_matrices(index))) - -# return dict(gen_matrices()) - -# def get_value(self, variable, value): -# variable_indices = get_variable_indices_from_variable(self.state, variable)[1] - -# def gen_value_index(): -# for variable_index in variable_indices: -# try: -# yield self.variable_mapping.index(variable_index) -# except ValueError: -# raise ValueError(f'{variable_index} not found in {self.variable_mapping}') - -# value_index = list(gen_value_index()) - -# return value[value_index] - -# def set_value(self, variable, value): -# variable_indices = get_variable_indices_from_variable(self.state, variable)[1] -# value_index = list(self.variable_mapping.index(variable_index) for variable_index in variable_indices) -# vec = np.zeros(len(self.variable_mapping)) -# vec[value_index] = value -# return vec - -# # def get_matrix(self, eq_idx: int): -# # equations = self.data[eq_idx].data - -# def get_func(self, eq_idx: int): -# equations = self.data[eq_idx].data -# max_idx = max(equations.keys()) - -# if 2 <= max_idx: -# def func(x: np.ndarray) -> np.ndarray: -# if isinstance(x, tuple) or isinstance(x, list): -# x = np.array(x).reshape(-1, 1) - -# elif x.shape[0] == 1: -# x = x.reshape(-1, 1) - -# def acc_x_powers(acc, _): -# next = (acc @ x.T).reshape(-1, 1) -# return next - -# x_powers = tuple(itertools.accumulate( -# range(max_idx - 1), -# acc_x_powers, -# initial=x, -# ))[1:] - -# def gen_value(): -# for idx, equation in equations.items(): -# if idx == 0: -# yield equation - -# elif idx == 1: -# yield equation @ x - -# else: -# yield equation @ x_powers[idx-2] - -# return sum(gen_value()) - -# else: -# def func(x: np.ndarray) -> np.ndarray: -# if isinstance(x, tuple) or isinstance(x, list): -# x = np.array(x).reshape(-1, 1) - -# def gen_value(): -# for idx, equation in equations.items(): -# if idx == 0: -# yield equation - -# else: -# yield equation @ x - -# return sum(gen_value()) - -# return func - -# def to_matrix_repr( -# expressions: Expression | tuple[Expression], -# variables: Expression = None, -# sorted: bool = None, -# ) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]: - -# if isinstance(expressions, Expression): -# expressions = (expressions,) - -# assert isinstance(variables, ExpressionBaseMixin) or variables is None, f'{variables=}' - -# def func(state: ExpressionState): - -# def acc_underlying_application(acc, v): -# state, underlying_list = acc - -# state, underlying = v.apply(state) - -# assert underlying.shape[1] == 1, f'{underlying.shape[1]=} is not 1' - -# return state, underlying_list + (underlying,) - -# *_, (state, polymatrix_list) = tuple(itertools.accumulate( -# expressions, -# acc_underlying_application, -# initial=(state, tuple()), -# )) - -# if variables is None: -# sorted_variable_index = tuple() - -# else: -# state, variable_index = get_variable_indices_from_variable(state, variables) - -# if sorted: -# tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index) - -# sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1])) - -# else: -# sorted_variable_index = variable_index - -# sorted_variable_index_set = set(sorted_variable_index) -# if len(sorted_variable_index) != len(sorted_variable_index_set): -# duplicates = tuple(state.get_name_from_offset(var) for var in sorted_variable_index_set if 1 < sorted_variable_index.count(var)) -# raise Exception(f'{duplicates=}. Make sure you give a unique name for each variables.') - -# variable_index_map = {old: new for new, old in enumerate(sorted_variable_index)} - -# n_param = len(sorted_variable_index) - -# def gen_numpy_matrices(): -# for polymatrix in polymatrix_list: - -# n_row = polymatrix.shape[0] - -# buffer = MatrixBuffer( -# data={}, -# n_row=n_row, -# n_param=n_param, -# ) - -# for row in range(n_row): -# polymatrix_terms = polymatrix.get_poly(row, 0) - -# if polymatrix_terms is None: -# continue - -# if len(polymatrix_terms) == 0: -# buffer.add(row, 0, 0, 0) - -# else: -# for monomial, value in polymatrix_terms.items(): - -# def gen_new_monomial(): -# for var, count in monomial: -# try: -# index = variable_index_map[var] -# except KeyError: -# raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}') - -# for _ in range(count): -# yield index - -# new_monomial = tuple(gen_new_monomial()) - -# cols = monomial_to_index(n_param, new_monomial) - -# col_value = value / len(cols) - -# for col in cols: -# degree = sum(count for _, count in monomial) -# buffer.add(row, col, degree, col_value) - -# yield buffer - -# underlying_matrices = tuple(gen_numpy_matrices()) - -# def gen_auxillary_equations(): -# for key, monomial_terms in state.auxillary_equations.items(): -# if key in sorted_variable_index: -# yield key, monomial_terms - -# auxillary_equations = tuple(gen_auxillary_equations()) - -# n_row = len(auxillary_equations) - -# if n_row == 0: -# auxillary_matrix_equations = None - -# else: -# buffer = MatrixBuffer( -# data={}, -# n_row=n_row, -# n_param=n_param, -# ) - -# for row, (key, monomial_terms) in enumerate(auxillary_equations): -# for monomial, value in monomial_terms.items(): -# new_monomial = tuple(variable_index_map[var] for var, count in monomial for _ in range(count)) - -# cols = monomial_to_index(n_param, new_monomial) - -# col_value = value / len(cols) - -# for col in cols: -# buffer.add(row, col, sum(count for _, count in monomial), col_value) - -# auxillary_matrix_equations = buffer - -# result = MatrixRepresentations( -# data=underlying_matrices, -# aux_data=auxillary_matrix_equations, -# variable_mapping=sorted_variable_index, -# state=state, -# ) - -# return state, result - -# return init_state_monad(func) - - -# def to_constant_repr( -# expr: Expression, -# assert_constant: bool = True, -# ) -> StateMonadMixin[ExpressionState, np.ndarray]: - -# def func(state: ExpressionState): -# state, underlying = expr.apply(state) - -# A = np.zeros(underlying.shape, dtype=np.double) - -# for (row, col), polynomial in underlying.gen_terms(): -# for monomial, value in polynomial.items(): -# if len(monomial) == 0: -# A[row, col] = value - -# elif assert_constant: -# raise Exception(f'non-constant term {monomial=}') - -# return state, A - -# return init_state_monad(func) - - -# def degrees( -# expr: Expression, -# variables: Expression, -# ) -> StateMonadMixin[ExpressionState, np.ndarray]: - -# def func(state: ExpressionState): -# state, underlying = expr.apply(state) -# state, variable_indices = get_variable_indices_from_variable(state, variables) - -# def gen_rows(): -# for row in range(underlying.shape[0]): -# def gen_cols(): -# for col in range(underlying.shape[1]): - -# def gen_degrees(): -# polynomial = underlying.get_poly(row, col) - -# if polynomial is None: -# yield 0 - -# else: -# for monomial, _ in polynomial.items(): -# yield sum(count for var, count in monomial if var in variable_indices) - -# yield tuple(set(gen_degrees())) - -# yield tuple(gen_cols()) - -# return state, tuple(gen_rows()) - -# return init_state_monad(func) - - -# def to_sympy_repr( -# expr: Expression, -# ) -> StateMonadMixin[ExpressionState, sympy.Expr]: - -# def func(state: ExpressionState): -# state, underlying = expr.apply(state) - -# A = np.zeros(underlying.shape, dtype=object) - -# for (row, col), polynomial in underlying.gen_terms(): - -# sympy_polynomial = 0 - -# for monomial, value in polynomial.items(): -# sympy_monomial = 1 - -# for offset, count in monomial: - -# variable = state.get_key_from_offset(offset) -# # def get_variable_from_offset(offset: int): -# # for variable, (start, end) in state.offset_dict.items(): -# # if start <= offset < end: -# # assert end - start == 1, f'{start=}, {end=}, {variable=}' - -# if isinstance(variable, sympy.core.symbol.Symbol): -# variable_name = variable.name -# elif isinstance(variable, (ParametrizeExprMixin, ParametrizeMatrixExprMixin)): -# variable_name = variable.name -# elif isinstance(variable, str): -# variable_name = variable -# else: -# raise Exception(f'{variable=}') - -# start, end = state.offset_dict[variable] - -# if end - start == 1: -# sympy_var = sympy.Symbol(variable_name) -# else: -# sympy_var = sympy.Symbol(f'{variable_name}_{offset - start + 1}') - -# # var = get_variable_from_offset(offset) -# sympy_monomial *= sympy_var**count - -# sympy_polynomial += value * sympy_monomial - -# A[row, col] = sympy_polynomial - -# return state, A - -# return init_state_monad(func) diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py index e44228f..7bafb5d 100644 --- a/polymatrix/denserepr/from_.py +++ b/polymatrix/denserepr/from_.py @@ -9,7 +9,7 @@ from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable from polymatrix.statemonad.init import init_state_monad from polymatrix.statemonad.mixins import StateMonadMixin -from polymatrix.expression.utils.monomialtoindex import monomial_to_index +from polymatrix.denserepr.utils.monomialtoindex import variable_indices_to_column_index def from_polymatrix( @@ -97,9 +97,9 @@ def from_polymatrix( for _ in range(count): yield index - new_monomial = tuple(gen_new_monomial()) + new_variable_indices = tuple(gen_new_monomial()) - cols = monomial_to_index(n_param, new_monomial) + cols = variable_indices_to_column_index(n_param, new_variable_indices) col_value = value / len(cols) @@ -134,7 +134,7 @@ def from_polymatrix( for monomial, value in monomial_terms.items(): new_monomial = tuple(variable_index_map[var] for var, count in monomial for _ in range(count)) - cols = monomial_to_index(n_param, new_monomial) + cols = variable_indices_to_column_index(n_param, new_monomial) col_value = value / len(cols) diff --git a/polymatrix/denserepr/utils/monomialtoindex.py b/polymatrix/denserepr/utils/monomialtoindex.py new file mode 100644 index 0000000..c13adfd --- /dev/null +++ b/polymatrix/denserepr/utils/monomialtoindex.py @@ -0,0 +1,11 @@ +import itertools + + +def variable_indices_to_column_index( + n_var: int, + variable_indices: tuple[int, ...], +) -> int: + + variable_indices_perm = itertools.permutations(variable_indices) + + return set(sum(idx*(n_var**level) for level, idx in enumerate(monomial)) for monomial in variable_indices_perm) diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py index da669f4..9da69af 100644 --- a/polymatrix/expression/mixins/combinationsexprmixin.py +++ b/polymatrix/expression/mixins/combinationsexprmixin.py @@ -1,13 +1,12 @@ import abc import itertools -from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial +from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState -from polymatrix.expression.utils.multiplymonomials import multiply_monomials class CombinationsExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index 1728a2d..38ea9cc 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -6,7 +6,7 @@ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState -from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials +from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable from polymatrix.utils.getstacklines import FrameSummary from polymatrix.utils.tooperatorexception import to_operator_exception @@ -60,23 +60,23 @@ class DerivativeExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): - underlying_terms = underlying.get_poly(row, 0) - if underlying_terms is None: + polynomial = underlying.get_poly(row, 0) + if polynomial is None: continue # derivate each variable and map result to the corresponding column for col, diff_wrt_variable in enumerate(diff_wrt_variables): - state, derivation_terms = get_derivative_monomials( - monomial_terms=underlying_terms, + state, derivation = differentiate_polynomial( + polynomial=polynomial, diff_wrt_variable=diff_wrt_variable, state=state, considered_variables=set(diff_wrt_variables), introduce_derivatives=self.introduce_derivatives, ) - if 0 < len(derivation_terms): - terms[row, col] = derivation_terms + if 0 < len(derivation): + terms[row, col] = derivation poly_matrix = init_poly_matrix( terms=terms, diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py index 29d1a0a..10db5b4 100644 --- a/polymatrix/expression/mixins/divergenceexprmixin.py +++ b/polymatrix/expression/mixins/divergenceexprmixin.py @@ -7,7 +7,7 @@ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState -from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials +from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable @@ -38,19 +38,20 @@ class DivergenceExprMixin(ExpressionBaseMixin): for row, variable in enumerate(variables): - underlying_terms = underlying.get_poly(row, 0) - if underlying_terms is None: + polynomial = underlying.get_poly(row, 0) + + if polynomial is None: continue - state, derivation_terms = get_derivative_monomials( - monomial_terms=underlying_terms, + state, derivation = differentiate_polynomial( + polynomial=polynomial, diff_wrt_variable=variable, state=state, considered_variables=set(), introduce_derivatives=False, ) - for monomial, value in derivation_terms.items(): + for monomial, value in derivation.items(): monomial_terms[monomial] += value poly_matrix = init_poly_matrix( diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index e6e64b1..17be163 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -9,7 +9,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState -from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices +from polymatrix.polymatrix.utils.mergemonomialindices import merge_monomial_indices class ElemMultExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py index 516a7a4..ed9e79f 100644 --- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py @@ -7,7 +7,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate.mixins import ExpressionStateMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable -from polymatrix.expression.utils.sortmonomials import sort_monomials +from polymatrix.polymatrix.utils.sortmonomials import sort_monomials class LinearMonomialsExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py index 3343a65..b7ae5ce 100644 --- a/polymatrix/expression/mixins/matrixmultexprmixin.py +++ b/polymatrix/expression/mixins/matrixmultexprmixin.py @@ -6,7 +6,7 @@ from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState -from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial +from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.utils.tooperatorexception import to_operator_exception diff --git a/polymatrix/expression/mixins/productexprmixin.py b/polymatrix/expression/mixins/productexprmixin.py index 45be74b..0168e27 100644 --- a/polymatrix/expression/mixins/productexprmixin.py +++ b/polymatrix/expression/mixins/productexprmixin.py @@ -1,14 +1,12 @@ import abc import itertools -from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial +from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.utils.getstacklines import FrameSummary -from polymatrix.utils.tooperatorexception import to_operator_exception from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState -from polymatrix.expression.utils.multiplymonomials import multiply_monomials from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py index 19c1610..e739364 100644 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ b/polymatrix/expression/mixins/quadraticinexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable -from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices +from polymatrix.polymatrix.utils.splitmonomialindices import split_monomial_indices from polymatrix.utils.getstacklines import FrameSummary from polymatrix.utils.tooperatorexception import to_operator_exception diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py index ff16275..99befc8 100644 --- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py @@ -7,7 +7,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate.mixins import ExpressionStateMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable -from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices +from polymatrix.polymatrix.utils.splitmonomialindices import split_monomial_indices class QuadraticMonomialsExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py index e1c0c3d..b75a2de 100644 --- a/polymatrix/expression/mixins/substituteexprmixin.py +++ b/polymatrix/expression/mixins/substituteexprmixin.py @@ -9,7 +9,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable -from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial +from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial class SubstituteExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py index 37add0a..9275332 100644 --- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py @@ -7,8 +7,8 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expressionstate.mixins import ExpressionStateMixin from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.expression.utils.getmonomialindices import get_monomial_indices -from polymatrix.expression.utils.sortmonomials import sort_monomials -from polymatrix.expression.utils.subtractmonomialindices import SubtractError, subtract_monomial_indices +from polymatrix.polymatrix.utils.sortmonomials import sort_monomials +from polymatrix.polymatrix.utils.subtractmonomialindices import SubtractError, subtract_monomial_indices class SubtractMonomialsExprMixin(ExpressionBaseMixin): diff --git a/polymatrix/expression/utils/getderivativemonomials.py b/polymatrix/expression/utils/getderivativemonomials.py index a8b2e16..83083aa 100644 --- a/polymatrix/expression/utils/getderivativemonomials.py +++ b/polymatrix/expression/utils/getderivativemonomials.py @@ -3,10 +3,11 @@ import dataclasses import itertools from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.polymatrix.typing import PolynomialData -def get_derivative_monomials( - monomial_terms, +def differentiate_polynomial( + polynomial: PolynomialData, diff_wrt_variable: int, state: ExpressionState, considered_variables: set, @@ -21,7 +22,7 @@ def get_derivative_monomials( if introduce_derivatives: def gen_new_variables(): - for monomial in monomial_terms.keys(): + for monomial in polynomial.keys(): for var in monomial: if var is not diff_wrt_variable and var not in considered_variables: yield var @@ -42,8 +43,8 @@ def get_derivative_monomials( state = state.register(key=key, n_param=1) # for each new variable we expect an auxillary equation - state, auxillary_derivation_terms = get_derivative_monomials( - monomial_terms=state.auxillary_equations[new_variable], + state, auxillary_derivation_terms = differentiate_polynomial( + polynomial=state.auxillary_equations[new_variable], diff_wrt_variable=diff_wrt_variable, state=state, considered_variables=new_considered_variables, @@ -75,7 +76,7 @@ def get_derivative_monomials( derivation_terms = collections.defaultdict(float) - for monomial, value in monomial_terms.items(): + for monomial, value in polynomial.items(): monomial_cnt = dict(monomial) diff --git a/polymatrix/expression/utils/getmonomialindices.py b/polymatrix/expression/utils/getmonomialindices.py index 42f9ca7..dcb7932 100644 --- a/polymatrix/expression/utils/getmonomialindices.py +++ b/polymatrix/expression/utils/getmonomialindices.py @@ -2,7 +2,10 @@ from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -def get_monomial_indices(state: ExpressionState, monomials: ExpressionBaseMixin) -> tuple[ExpressionState, tuple[int, ...]]: +def get_monomial_indices( + state: ExpressionState, + monomials: ExpressionBaseMixin, + ) -> tuple[ExpressionState, tuple[int, ...]]: state, monomials_obj = monomials.apply(state) diff --git a/polymatrix/expression/utils/gettupleremainder.py b/polymatrix/expression/utils/gettupleremainder.py deleted file mode 100644 index 2a22ead..0000000 --- a/polymatrix/expression/utils/gettupleremainder.py +++ /dev/null @@ -1,9 +0,0 @@ -def get_tuple_remainder(l1, l2): - # create a copy of l1 - l1_copy = list(l1) - - # remove from the copy all elements of l2 - for e in l2: - l1_copy.remove(e) - - return tuple(l1_copy) \ No newline at end of file diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index 4ee7037..d61da57 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -1,9 +1,14 @@ import itertools +import typing +from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -def get_variable_indices_from_variable(state, variable) -> tuple[int] | None: +def get_variable_indices_from_variable( + state: ExpressionState, + variable: ExpressionBaseMixin | int | typing.Any, +) -> tuple[int, ...] | None: if isinstance(variable, ExpressionBaseMixin): state, variable_polynomial = variable.apply(state) @@ -40,13 +45,12 @@ def get_variable_indices_from_variable(state, variable) -> tuple[int] | None: return state, variable_indices +# not used, remove? def get_variable_indices(state, variables): if not isinstance(variables, tuple): variables = (variables,) - # assert isinstance(variables, tuple), f'{variables=}' - def acc_variable_indices(acc, variable): state, indices = acc @@ -61,43 +65,3 @@ def get_variable_indices(state, variables): ) return state, indices - - - # global_state = [state] - - # assert isinstance(variables, tuple), f'{variables=}' - - # # if not isinstance(variables, tuple): - # # variables = (variables,) - - # def gen_indices(): - # for variable in variables: - # if isinstance(variable, ExpressionBaseMixin): - # global_state[0], variable_polynomial = variable.apply(global_state[0]) - - # assert variable_polynomial.shape[1] == 1 - - # for row in range(variable_polynomial.shape[0]): - # row_terms = variable_polynomial.get_poly(row, 0) - - # assert len(row_terms) == 1, f'{row_terms} contains more than one term' - - # for monomial in row_terms.keys(): - # assert len(monomial) <= 1, f'{monomial=} contains more than one variable' - - # if len(monomial) == 0: - # continue - - # assert monomial[0][1] == 1, f'{monomial[0]=}' - # yield monomial[0][0] - - # elif isinstance(variable, int): - # yield variable - - # # else: - # elif variable in global_state[0].offset_dict: - # yield global_state[0].offset_dict[variable][0] - - # indices = tuple(gen_indices()) - - # return global_state[0], indices diff --git a/polymatrix/expression/utils/mergemonomialindices.py b/polymatrix/expression/utils/mergemonomialindices.py deleted file mode 100644 index c572852..0000000 --- a/polymatrix/expression/utils/mergemonomialindices.py +++ /dev/null @@ -1,22 +0,0 @@ -from polymatrix.expression.utils.sortmonomialindices import sort_monomial_indices - - -def merge_monomial_indices(monomials): - if len(monomials) == 0: - return tuple() - - elif len(monomials) == 1: - return monomials[0] - - else: - - m1_dict = dict(monomials[0]) - - for other in monomials[1:]: - for index, count in other: - if index in m1_dict: - m1_dict[index] += count - else: - m1_dict[index] = count - - return sort_monomial_indices(m1_dict.items()) diff --git a/polymatrix/expression/utils/monomialtoindex.py b/polymatrix/expression/utils/monomialtoindex.py deleted file mode 100644 index bb1e570..0000000 --- a/polymatrix/expression/utils/monomialtoindex.py +++ /dev/null @@ -1,7 +0,0 @@ -import itertools - - -def monomial_to_index(n_var, monomial): - monomial_perm = itertools.permutations(monomial) - - return set(sum(idx*(n_var**level) for level, idx in enumerate(monomial)) for monomial in monomial_perm) diff --git a/polymatrix/expression/utils/multiplymonomials.py b/polymatrix/expression/utils/multiplymonomials.py deleted file mode 100644 index a4b3273..0000000 --- a/polymatrix/expression/utils/multiplymonomials.py +++ /dev/null @@ -1,15 +0,0 @@ -import itertools -import math - -from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices - - -def multiply_monomials(left_monomials, right_monomials): - def gen_monomials(): - for left_monomial, right_monomial in itertools.product(left_monomials, right_monomials): - - monomial = merge_monomial_indices((left_monomial, right_monomial)) - - yield monomial - - return gen_monomials() diff --git a/polymatrix/expression/utils/multiplypolynomial.py b/polymatrix/expression/utils/multiplypolynomial.py deleted file mode 100644 index e27a124..0000000 --- a/polymatrix/expression/utils/multiplypolynomial.py +++ /dev/null @@ -1,28 +0,0 @@ -import itertools -import math - -from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices -from polymatrix.polymatrix.typing import PolynomialData - -def multiply_polynomial(left: PolynomialData, right: PolynomialData, result: PolynomialData): - """ - Multiplies two polynomials `left` and `right` and adds the result to the mutable polynomial `result`. - """ - - for (left_monomial, left_value), (right_monomial, right_value) \ - in itertools.product(left.items(), right.items()): - - value = left_value * right_value - - if math.isclose(value, 0, abs_tol=1e-12): - continue - - monomial = merge_monomial_indices((left_monomial, right_monomial)) - - if monomial not in result: - result[monomial] = 0 - - result[monomial] += value - - if math.isclose(result[monomial], 0, abs_tol=1e-12): - del result[monomial] diff --git a/polymatrix/expression/utils/sortmonomialindices.py b/polymatrix/expression/utils/sortmonomialindices.py deleted file mode 100644 index 3d4734e..0000000 --- a/polymatrix/expression/utils/sortmonomialindices.py +++ /dev/null @@ -1,5 +0,0 @@ -def sort_monomial_indices(monomial): - return tuple(sorted( - monomial, - key=lambda m: m[0], - )) \ No newline at end of file diff --git a/polymatrix/expression/utils/sortmonomials.py b/polymatrix/expression/utils/sortmonomials.py deleted file mode 100644 index 46e7118..0000000 --- a/polymatrix/expression/utils/sortmonomials.py +++ /dev/null @@ -1,5 +0,0 @@ -def sort_monomials(monomials): - return tuple(sorted( - monomials, - key=lambda m: (sum(count for _, count in m), len(m), m), - )) \ No newline at end of file diff --git a/polymatrix/expression/utils/splitmonomialindices.py b/polymatrix/expression/utils/splitmonomialindices.py deleted file mode 100644 index 4ffa6ff..0000000 --- a/polymatrix/expression/utils/splitmonomialindices.py +++ /dev/null @@ -1,24 +0,0 @@ -def split_monomial_indices(monomial): - left = [] - right = [] - - is_left = True - - for idx, count in monomial: - count_left = count // 2 - - if count % 2: - if is_left: - count_left = count_left + 1 - - is_left = not is_left - - count_right = count - count_left - - if 0 < count_left: - left.append((idx, count_left)) - - if 0 < count_right: - right.append((idx, count - count_left)) - - return tuple(left), tuple(right) \ No newline at end of file diff --git a/polymatrix/expression/utils/subtractmonomialindices.py b/polymatrix/expression/utils/subtractmonomialindices.py deleted file mode 100644 index 236ba9c..0000000 --- a/polymatrix/expression/utils/subtractmonomialindices.py +++ /dev/null @@ -1,23 +0,0 @@ -from polymatrix.expression.utils.sortmonomialindices import sort_monomial_indices - - -class SubtractError(Exception): - pass - - -def subtract_monomial_indices(m1, m2): - m1_dict = dict(m1) - - for index, count in m2: - if index not in m1_dict: - raise SubtractError() - - m1_dict[index] -= count - - if m1_dict[index] == 0: - del m1_dict[index] - - elif m1_dict[index] < 0: - raise SubtractError() - - return sort_monomial_indices(m1_dict.items()) diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py index fe5946e..f957685 100644 --- a/polymatrix/polymatrix/impl.py +++ b/polymatrix/polymatrix/impl.py @@ -1,13 +1,14 @@ import dataclassabc from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin +from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin +from polymatrix.polymatrix.typing import PolynomialData @dataclassabc.dataclassabc(frozen=True) class PolyMatrixImpl(PolyMatrix): - terms: dict - shape: tuple[int, ...] + terms: dict[tuple[int, int], PolynomialData] + shape: tuple[int, int] @dataclassabc.dataclassabc(frozen=True) diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py index 7f0b60f..2dc2a63 100644 --- a/polymatrix/polymatrix/init.py +++ b/polymatrix/polymatrix/init.py @@ -1,9 +1,10 @@ from polymatrix.polymatrix.impl import PolyMatrixImpl +from polymatrix.polymatrix.typing import PolynomialData def init_poly_matrix( - terms: dict, - shape: tuple, + terms: dict[tuple[int, int], PolynomialData], + shape: tuple[int, int], ): return PolyMatrixImpl( diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py index 1aa0f19..5046492 100644 --- a/polymatrix/polymatrix/mixins.py +++ b/polymatrix/polymatrix/mixins.py @@ -1,6 +1,8 @@ import abc import typing +from polymatrix.polymatrix.typing import PolynomialData + class PolyMatrixMixin(abc.ABC): @property @@ -8,7 +10,7 @@ class PolyMatrixMixin(abc.ABC): def shape(self) -> tuple[int, int]: ... - def gen_terms(self) -> typing.Generator[tuple[tuple[int, int], dict[tuple[int, ...], float]], None, None]: + def gen_terms(self) -> typing.Generator[tuple[tuple[int, int], PolynomialData], None, None]: for row in range(self.shape[0]): for col in range(self.shape[1]): polynomial = self.get_poly(row, col) @@ -18,7 +20,7 @@ class PolyMatrixMixin(abc.ABC): yield (row, col), polynomial @abc.abstractclassmethod - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float] | None: + def get_poly(self, row: int, col: int) -> PolynomialData | None: ... @@ -28,11 +30,11 @@ class PolyMatrixAsDictMixin( ): @property @abc.abstractmethod - def terms(self) -> dict[tuple[int, int], dict[tuple[int, ...], float]]: + def terms(self) -> dict[tuple[int, int], PolynomialData]: ... # overwrites the abstract method of `PolyMatrixMixin` - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float] | None: + def get_poly(self, row: int, col: int) -> PolynomialData | None: if (row, col) in self.terms: return self.terms[row, col] @@ -43,9 +45,9 @@ class BroadcastPolyMatrixMixin( ): @property @abc.abstractmethod - def polynomial(self) -> dict[tuple[int, ...], float]: + def polynomial(self) -> PolynomialData: ... # overwrites the abstract method of `PolyMatrixMixin` - def get_poly(self, col: int, row: int) -> dict[tuple[int, ...], float] | None: + def get_poly(self, col: int, row: int) -> PolynomialData | None: return self.polynomial diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py index f8031e5..0540054 100644 --- a/polymatrix/polymatrix/typing.py +++ b/polymatrix/polymatrix/typing.py @@ -1,4 +1,8 @@ +# monomial x1**2 x2 +# with indices {x1: 0, x2: 1} +# is represented as ((0, 2), (1, 1)) +MonomialData = tuple[tuple[int, int], ...] -PolynomialData = dict[tuple[int, ...], float] -PolynomialMatrixData = dict[tuple[int, int], dict[tuple[int, ...], float]] \ No newline at end of file +PolynomialData = dict[MonomialData, float] +PolynomialMatrixData = dict[tuple[int, int], dict[MonomialData, float]] \ No newline at end of file diff --git a/polymatrix/polymatrix/utils/mergemonomialindices.py b/polymatrix/polymatrix/utils/mergemonomialindices.py new file mode 100644 index 0000000..8285c98 --- /dev/null +++ b/polymatrix/polymatrix/utils/mergemonomialindices.py @@ -0,0 +1,38 @@ +from polymatrix.polymatrix.typing import MonomialData +from polymatrix.polymatrix.utils.sortmonomialindices import sort_monomial_indices + + +def merge_monomial_indices( + monomials: tuple[MonomialData, ...], +) -> MonomialData: + """ + (x1**2 x2, x2**2) -> x1**2 x2**3 + + or in terms of indices {x1: 0, x2: 1}: + + ( + ((0, 2), (1, 1)), # x1**2 x2 + ((1, 2),) # x2**2 + ) -> ((0, 2), (1, 3)) # x1**1 x2**3 + """ + + if len(monomials) == 0: + return tuple() + + elif len(monomials) == 1: + return monomials[0] + + else: + + m1_dict = dict(monomials[0]) + + for other in monomials[1:]: + for index, count in other: + if index in m1_dict: + m1_dict[index] += count + else: + m1_dict[index] = count + + # sort monomials according to their index + # ((1, 3), (0, 2)) -> ((0, 2), (1, 3)) + return sort_monomial_indices(m1_dict.items()) diff --git a/polymatrix/polymatrix/utils/multiplypolynomial.py b/polymatrix/polymatrix/utils/multiplypolynomial.py new file mode 100644 index 0000000..17fd154 --- /dev/null +++ b/polymatrix/polymatrix/utils/multiplypolynomial.py @@ -0,0 +1,32 @@ +import itertools +import math + +from polymatrix.polymatrix.utils.mergemonomialindices import merge_monomial_indices +from polymatrix.polymatrix.typing import PolynomialData + +def multiply_polynomial( + left: PolynomialData, + right: PolynomialData, + result: PolynomialData, + ) -> None: + """ + Multiplies two polynomials `left` and `right` and adds the result to the mutable polynomial `result`. + """ + + for (left_monomial, left_value), (right_monomial, right_value) \ + in itertools.product(left.items(), right.items()): + + value = left_value * right_value + + if math.isclose(value, 0, abs_tol=1e-12): + continue + + monomial = merge_monomial_indices((left_monomial, right_monomial)) + + if monomial not in result: + result[monomial] = 0 + + result[monomial] += value + + if math.isclose(result[monomial], 0, abs_tol=1e-12): + del result[monomial] diff --git a/polymatrix/polymatrix/utils/sortmonomialindices.py b/polymatrix/polymatrix/utils/sortmonomialindices.py new file mode 100644 index 0000000..2a13bc3 --- /dev/null +++ b/polymatrix/polymatrix/utils/sortmonomialindices.py @@ -0,0 +1,10 @@ +from polymatrix.polymatrix.typing import MonomialData + + +def sort_monomial_indices( + monomial: MonomialData, +) -> MonomialData: + return tuple(sorted( + monomial, + key=lambda m: m[0], + )) diff --git a/polymatrix/polymatrix/utils/sortmonomials.py b/polymatrix/polymatrix/utils/sortmonomials.py new file mode 100644 index 0000000..e1c82e7 --- /dev/null +++ b/polymatrix/polymatrix/utils/sortmonomials.py @@ -0,0 +1,9 @@ +from polymatrix.polymatrix.typing import MonomialData + +def sort_monomials( + monomials: tuple[MonomialData], +) -> tuple[MonomialData]: + return tuple(sorted( + monomials, + key=lambda m: (sum(count for _, count in m), len(m), m), + )) diff --git a/polymatrix/polymatrix/utils/splitmonomialindices.py b/polymatrix/polymatrix/utils/splitmonomialindices.py new file mode 100644 index 0000000..8d5d148 --- /dev/null +++ b/polymatrix/polymatrix/utils/splitmonomialindices.py @@ -0,0 +1,29 @@ +from polymatrix.polymatrix.typing import MonomialData + + +def split_monomial_indices( + monomial: MonomialData, +) -> tuple[MonomialData, MonomialData]: + left = [] + right = [] + + is_left = True + + for idx, count in monomial: + count_left = count // 2 + + if count % 2: + if is_left: + count_left = count_left + 1 + + is_left = not is_left + + count_right = count - count_left + + if 0 < count_left: + left.append((idx, count_left)) + + if 0 < count_right: + right.append((idx, count - count_left)) + + return tuple(left), tuple(right) diff --git a/polymatrix/polymatrix/utils/subtractmonomialindices.py b/polymatrix/polymatrix/utils/subtractmonomialindices.py new file mode 100644 index 0000000..c62bf3e --- /dev/null +++ b/polymatrix/polymatrix/utils/subtractmonomialindices.py @@ -0,0 +1,28 @@ +from polymatrix.polymatrix.typing import MonomialData + +from polymatrix.polymatrix.utils.sortmonomialindices import sort_monomial_indices + + +class SubtractError(Exception): + pass + + +def subtract_monomial_indices( + m1: MonomialData, + m2: MonomialData, +) -> MonomialData: + m1_dict = dict(m1) + + for index, count in m2: + if index not in m1_dict: + raise SubtractError() + + m1_dict[index] -= count + + if m1_dict[index] == 0: + del m1_dict[index] + + elif m1_dict[index] < 0: + raise SubtractError() + + return sort_monomial_indices(m1_dict.items()) -- cgit v1.2.1