summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2024-02-27 07:39:52 +0100
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2024-02-27 07:39:52 +0100
commit6a9ca2e0c327281c1359b6152f6afc9165e6bf9b (patch)
tree1630c0c5a460ba4f9f207eafb8809da6f5619941
parentreplace indentation tabs with spaces (diff)
downloadpolymatrix-6a9ca2e0c327281c1359b6152f6afc9165e6bf9b.tar.gz
polymatrix-6a9ca2e0c327281c1359b6152f6afc9165e6bf9b.zip
improve typing with PolynomialData and MonomialData
Diffstat (limited to '')
-rw-r--r--polymatrix/__init__.py519
-rw-r--r--polymatrix/denserepr/from_.py8
-rw-r--r--polymatrix/denserepr/utils/monomialtoindex.py11
-rw-r--r--polymatrix/expression/mixins/combinationsexprmixin.py3
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py14
-rw-r--r--polymatrix/expression/mixins/divergenceexprmixin.py13
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/linearmonomialsexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/matrixmultexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/productexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/quadraticinexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/quadraticmonomialsexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/substituteexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/subtractmonomialsexprmixin.py4
-rw-r--r--polymatrix/expression/utils/getderivativemonomials.py13
-rw-r--r--polymatrix/expression/utils/getmonomialindices.py5
-rw-r--r--polymatrix/expression/utils/gettupleremainder.py9
-rw-r--r--polymatrix/expression/utils/getvariableindices.py50
-rw-r--r--polymatrix/expression/utils/mergemonomialindices.py22
-rw-r--r--polymatrix/expression/utils/monomialtoindex.py7
-rw-r--r--polymatrix/expression/utils/multiplymonomials.py15
-rw-r--r--polymatrix/expression/utils/sortmonomialindices.py5
-rw-r--r--polymatrix/expression/utils/sortmonomials.py5
-rw-r--r--polymatrix/polymatrix/impl.py7
-rw-r--r--polymatrix/polymatrix/init.py5
-rw-r--r--polymatrix/polymatrix/mixins.py14
-rw-r--r--polymatrix/polymatrix/typing.py8
-rw-r--r--polymatrix/polymatrix/utils/mergemonomialindices.py38
-rw-r--r--polymatrix/polymatrix/utils/multiplypolynomial.py (renamed from polymatrix/expression/utils/multiplypolynomial.py)8
-rw-r--r--polymatrix/polymatrix/utils/sortmonomialindices.py10
-rw-r--r--polymatrix/polymatrix/utils/sortmonomials.py9
-rw-r--r--polymatrix/polymatrix/utils/splitmonomialindices.py (renamed from polymatrix/expression/utils/splitmonomialindices.py)9
-rw-r--r--polymatrix/polymatrix/utils/subtractmonomialindices.py (renamed from polymatrix/expression/utils/subtractmonomialindices.py)9
33 files changed, 156 insertions, 680 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 6592238..45968aa 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -2,541 +2,24 @@ from polymatrix.expressionstate.abc import ExpressionState as internal_Expressio
from polymatrix.expressionstate.init import init_expression_state as internal_init_expression_state
from polymatrix.expression.expression import Expression as internal_Expression
from polymatrix.expression.from_ import from_ as internal_from
-# from polymatrix.expression.from_ import from_sympy as internal_from_sympy
from polymatrix.expression import v_stack as internal_v_stack
from polymatrix.expression import h_stack as internal_h_stack
from polymatrix.expression import product as internal_product
-from polymatrix.denserepr.from_ import from_polymatrix
-# from polymatrix.expression.to import shape as internal_shape
from polymatrix.expression.to import to_constant as internal_to_constant
-# from polymatrix.expression.to import to_degrees as internal_to_degrees
from polymatrix.expression.to import to_sympy as internal_to_sympy
+from polymatrix.denserepr.from_ import from_polymatrix
Expression = internal_Expression
ExpressionState = internal_ExpressionState
init_expression_state = internal_init_expression_state
from_ = internal_from
-# from_sympy = internal_from_sympy
v_stack = internal_v_stack
h_stack = internal_h_stack
product = internal_product
-# to_shape = internal_shape
to_constant_repr = internal_to_constant
to_constant = internal_to_constant
-# to_degrees = internal_to_degrees
to_sympy_repr = internal_to_sympy
to_sympy = internal_to_sympy
to_matrix_repr = from_polymatrix
to_dense = from_polymatrix
-
-# def from_sympy(
-# data: tuple[tuple[float]],
-# ):
-# return init_expression(
-# init_from_sympy_expr(data)
-# )
-
-# def from_state_monad(
-# data: StateMonad,
-# ):
-# return init_expression(
-# data.flat_map(lambda inner_data: init_from_sympy_expr(inner_data)),
-# )
-
-# def from_polymatrix(
-# polymatrix: PolyMatrix,
-# ):
-# return init_expression(
-# init_from_terms_expr(polymatrix)
-# )
-
-# def from_(
-# data: str | tuple[tuple[float]],
-# ):
-# if isinstance(data, str):
-# return from_(1).parametrize(data)
-
-# return from_sympy(data)
-
-# def v_stack(
-# expressions: tuple[Expression],
-# ):
-# return init_expression(
-# init_v_stack_expr(expressions)
-# )
-
-# def h_stack(
-# expressions: tuple[Expression],
-# ):
-# return init_expression(
-# init_v_stack_expr(tuple(expr.T for expr in expressions))
-# ).T
-
-# def block_diag(
-# expressions: tuple[Expression],
-# ):
-# return init_expression(
-# init_block_diag_expr(expressions)
-# )
-
-# def eye(
-# variable: tuple[Expression],
-# ):
-# return init_expression(
-# init_eye_expr(variable=variable)
-# )
-
-# def kkt_equality(
-# equality: Expression,
-# variables: Expression,
-# name: str,
-# ):
-# equality_der = equality.diff(
-# variables,
-# introduce_derivatives=True,
-# )
-
-# nu = equality_der[0, :].T.parametrize(name)
-
-# return nu, equality_der @ nu
-
-# def kkt_inequality(
-# inequality: Expression,
-# variables: Expression,
-# name: str,
-# ):
-
-# inequality_der = inequality.diff(
-# variables,
-# introduce_derivatives=True,
-# )
-
-# nu = inequality_der[0, :].T.parametrize(f'lambda_{name}')
-# r_nu = inequality_der[0, :].T.parametrize(f'r_lambda_{name}')
-# r_ineq = inequality_der[0, :].T.parametrize(f'r_ineq_{name}')
-
-# prim_expr = inequality - r_ineq * r_ineq
-
-# sec_expr = nu - r_nu * r_nu
-
-# compl_expr = nu * inequality
-
-# cost_expr = inequality_der @ nu
-
-# return h_stack((nu, r_nu, r_ineq)), (cost_expr, prim_expr, sec_expr, compl_expr)
-
-
-# def rows(
-# expr: Expression,
-# ) -> StateMonadMixin[ExpressionState, np.ndarray]:
-
-# def func(state: ExpressionState):
-# state, underlying = expr.apply(state)
-
-# def gen_row_terms():
-# for row in range(underlying.shape[0]):
-
-# terms = {}
-
-# for col in range(underlying.shape[1]):
-# polynomial = underlying.get_poly(row, col)
-# if polynomial == None:
-# continue
-
-# terms[0, col] = polynomial
-
-# yield init_expression(underlying=init_from_terms_expr(
-# terms=terms,
-# shape=(1, underlying.shape[1])
-# ))
-
-# row_terms = tuple(gen_row_terms())
-
-# return state, row_terms
-
-# return init_state_monad(func)
-
-
-# @dataclasses.dataclass
-# class MatrixBuffer:
-# data: dict[int, np.ndarray]
-# n_row: int
-# n_param: int
-
-# def get_max_degree(self):
-# return max(degree for degree in self.data.keys())
-
-# def add_buffer(self, index: int):
-# if index <= 1:
-# buffer = np.zeros((self.n_row, self.n_param**index), dtype=np.double)
-
-# else:
-# buffer = scipy.sparse.dok_array((self.n_row, self.n_param**index), dtype=np.double)
-
-# self.data[index] = buffer
-
-# def add(self, row: int, col: int, index: int, value: float):
-# if index not in self.data:
-# self.add_buffer(index)
-
-# self.data[index][row, col] = value
-
-# def __getitem__(self, key):
-# if key not in self.data:
-# self.add_buffer(key)
-
-# return self.data[key]
-
-
-# @dataclasses.dataclass
-# class MatrixRepresentations:
-# data: tuple[MatrixBuffer, ...]
-# aux_data: typing.Optional[MatrixBuffer]
-# variable_mapping: tuple[int, ...]
-# state: ExpressionState
-
-# def merge_matrix_equations(self):
-# def gen_matrices(index: int):
-# for equations in self.data:
-# if index < len(equations):
-# yield equations[index]
-
-# if index < len(self.aux_data):
-# yield self.aux_data[index]
-
-# indices = set(key for equations in self.data + (self.aux_data,) for key in equations.keys())
-
-# def gen_matrices():
-# for index in indices:
-# if index <= 1:
-# yield index, np.vstack(tuple(gen_matrices(index)))
-# else:
-# yield index, scipy.sparse.vstack(tuple(gen_matrices(index)))
-
-# return dict(gen_matrices())
-
-# def get_value(self, variable, value):
-# variable_indices = get_variable_indices_from_variable(self.state, variable)[1]
-
-# def gen_value_index():
-# for variable_index in variable_indices:
-# try:
-# yield self.variable_mapping.index(variable_index)
-# except ValueError:
-# raise ValueError(f'{variable_index} not found in {self.variable_mapping}')
-
-# value_index = list(gen_value_index())
-
-# return value[value_index]
-
-# def set_value(self, variable, value):
-# variable_indices = get_variable_indices_from_variable(self.state, variable)[1]
-# value_index = list(self.variable_mapping.index(variable_index) for variable_index in variable_indices)
-# vec = np.zeros(len(self.variable_mapping))
-# vec[value_index] = value
-# return vec
-
-# # def get_matrix(self, eq_idx: int):
-# # equations = self.data[eq_idx].data
-
-# def get_func(self, eq_idx: int):
-# equations = self.data[eq_idx].data
-# max_idx = max(equations.keys())
-
-# if 2 <= max_idx:
-# def func(x: np.ndarray) -> np.ndarray:
-# if isinstance(x, tuple) or isinstance(x, list):
-# x = np.array(x).reshape(-1, 1)
-
-# elif x.shape[0] == 1:
-# x = x.reshape(-1, 1)
-
-# def acc_x_powers(acc, _):
-# next = (acc @ x.T).reshape(-1, 1)
-# return next
-
-# x_powers = tuple(itertools.accumulate(
-# range(max_idx - 1),
-# acc_x_powers,
-# initial=x,
-# ))[1:]
-
-# def gen_value():
-# for idx, equation in equations.items():
-# if idx == 0:
-# yield equation
-
-# elif idx == 1:
-# yield equation @ x
-
-# else:
-# yield equation @ x_powers[idx-2]
-
-# return sum(gen_value())
-
-# else:
-# def func(x: np.ndarray) -> np.ndarray:
-# if isinstance(x, tuple) or isinstance(x, list):
-# x = np.array(x).reshape(-1, 1)
-
-# def gen_value():
-# for idx, equation in equations.items():
-# if idx == 0:
-# yield equation
-
-# else:
-# yield equation @ x
-
-# return sum(gen_value())
-
-# return func
-
-# def to_matrix_repr(
-# expressions: Expression | tuple[Expression],
-# variables: Expression = None,
-# sorted: bool = None,
-# ) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]:
-
-# if isinstance(expressions, Expression):
-# expressions = (expressions,)
-
-# assert isinstance(variables, ExpressionBaseMixin) or variables is None, f'{variables=}'
-
-# def func(state: ExpressionState):
-
-# def acc_underlying_application(acc, v):
-# state, underlying_list = acc
-
-# state, underlying = v.apply(state)
-
-# assert underlying.shape[1] == 1, f'{underlying.shape[1]=} is not 1'
-
-# return state, underlying_list + (underlying,)
-
-# *_, (state, polymatrix_list) = tuple(itertools.accumulate(
-# expressions,
-# acc_underlying_application,
-# initial=(state, tuple()),
-# ))
-
-# if variables is None:
-# sorted_variable_index = tuple()
-
-# else:
-# state, variable_index = get_variable_indices_from_variable(state, variables)
-
-# if sorted:
-# tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index)
-
-# sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
-
-# else:
-# sorted_variable_index = variable_index
-
-# sorted_variable_index_set = set(sorted_variable_index)
-# if len(sorted_variable_index) != len(sorted_variable_index_set):
-# duplicates = tuple(state.get_name_from_offset(var) for var in sorted_variable_index_set if 1 < sorted_variable_index.count(var))
-# raise Exception(f'{duplicates=}. Make sure you give a unique name for each variables.')
-
-# variable_index_map = {old: new for new, old in enumerate(sorted_variable_index)}
-
-# n_param = len(sorted_variable_index)
-
-# def gen_numpy_matrices():
-# for polymatrix in polymatrix_list:
-
-# n_row = polymatrix.shape[0]
-
-# buffer = MatrixBuffer(
-# data={},
-# n_row=n_row,
-# n_param=n_param,
-# )
-
-# for row in range(n_row):
-# polymatrix_terms = polymatrix.get_poly(row, 0)
-
-# if polymatrix_terms is None:
-# continue
-
-# if len(polymatrix_terms) == 0:
-# buffer.add(row, 0, 0, 0)
-
-# else:
-# for monomial, value in polymatrix_terms.items():
-
-# def gen_new_monomial():
-# for var, count in monomial:
-# try:
-# index = variable_index_map[var]
-# except KeyError:
-# raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}')
-
-# for _ in range(count):
-# yield index
-
-# new_monomial = tuple(gen_new_monomial())
-
-# cols = monomial_to_index(n_param, new_monomial)
-
-# col_value = value / len(cols)
-
-# for col in cols:
-# degree = sum(count for _, count in monomial)
-# buffer.add(row, col, degree, col_value)
-
-# yield buffer
-
-# underlying_matrices = tuple(gen_numpy_matrices())
-
-# def gen_auxillary_equations():
-# for key, monomial_terms in state.auxillary_equations.items():
-# if key in sorted_variable_index:
-# yield key, monomial_terms
-
-# auxillary_equations = tuple(gen_auxillary_equations())
-
-# n_row = len(auxillary_equations)
-
-# if n_row == 0:
-# auxillary_matrix_equations = None
-
-# else:
-# buffer = MatrixBuffer(
-# data={},
-# n_row=n_row,
-# n_param=n_param,
-# )
-
-# for row, (key, monomial_terms) in enumerate(auxillary_equations):
-# for monomial, value in monomial_terms.items():
-# new_monomial = tuple(variable_index_map[var] for var, count in monomial for _ in range(count))
-
-# cols = monomial_to_index(n_param, new_monomial)
-
-# col_value = value / len(cols)
-
-# for col in cols:
-# buffer.add(row, col, sum(count for _, count in monomial), col_value)
-
-# auxillary_matrix_equations = buffer
-
-# result = MatrixRepresentations(
-# data=underlying_matrices,
-# aux_data=auxillary_matrix_equations,
-# variable_mapping=sorted_variable_index,
-# state=state,
-# )
-
-# return state, result
-
-# return init_state_monad(func)
-
-
-# def to_constant_repr(
-# expr: Expression,
-# assert_constant: bool = True,
-# ) -> StateMonadMixin[ExpressionState, np.ndarray]:
-
-# def func(state: ExpressionState):
-# state, underlying = expr.apply(state)
-
-# A = np.zeros(underlying.shape, dtype=np.double)
-
-# for (row, col), polynomial in underlying.gen_terms():
-# for monomial, value in polynomial.items():
-# if len(monomial) == 0:
-# A[row, col] = value
-
-# elif assert_constant:
-# raise Exception(f'non-constant term {monomial=}')
-
-# return state, A
-
-# return init_state_monad(func)
-
-
-# def degrees(
-# expr: Expression,
-# variables: Expression,
-# ) -> StateMonadMixin[ExpressionState, np.ndarray]:
-
-# def func(state: ExpressionState):
-# state, underlying = expr.apply(state)
-# state, variable_indices = get_variable_indices_from_variable(state, variables)
-
-# def gen_rows():
-# for row in range(underlying.shape[0]):
-# def gen_cols():
-# for col in range(underlying.shape[1]):
-
-# def gen_degrees():
-# polynomial = underlying.get_poly(row, col)
-
-# if polynomial is None:
-# yield 0
-
-# else:
-# for monomial, _ in polynomial.items():
-# yield sum(count for var, count in monomial if var in variable_indices)
-
-# yield tuple(set(gen_degrees()))
-
-# yield tuple(gen_cols())
-
-# return state, tuple(gen_rows())
-
-# return init_state_monad(func)
-
-
-# def to_sympy_repr(
-# expr: Expression,
-# ) -> StateMonadMixin[ExpressionState, sympy.Expr]:
-
-# def func(state: ExpressionState):
-# state, underlying = expr.apply(state)
-
-# A = np.zeros(underlying.shape, dtype=object)
-
-# for (row, col), polynomial in underlying.gen_terms():
-
-# sympy_polynomial = 0
-
-# for monomial, value in polynomial.items():
-# sympy_monomial = 1
-
-# for offset, count in monomial:
-
-# variable = state.get_key_from_offset(offset)
-# # def get_variable_from_offset(offset: int):
-# # for variable, (start, end) in state.offset_dict.items():
-# # if start <= offset < end:
-# # assert end - start == 1, f'{start=}, {end=}, {variable=}'
-
-# if isinstance(variable, sympy.core.symbol.Symbol):
-# variable_name = variable.name
-# elif isinstance(variable, (ParametrizeExprMixin, ParametrizeMatrixExprMixin)):
-# variable_name = variable.name
-# elif isinstance(variable, str):
-# variable_name = variable
-# else:
-# raise Exception(f'{variable=}')
-
-# start, end = state.offset_dict[variable]
-
-# if end - start == 1:
-# sympy_var = sympy.Symbol(variable_name)
-# else:
-# sympy_var = sympy.Symbol(f'{variable_name}_{offset - start + 1}')
-
-# # var = get_variable_from_offset(offset)
-# sympy_monomial *= sympy_var**count
-
-# sympy_polynomial += value * sympy_monomial
-
-# A[row, col] = sympy_polynomial
-
-# return state, A
-
-# return init_state_monad(func)
diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py
index e44228f..7bafb5d 100644
--- a/polymatrix/denserepr/from_.py
+++ b/polymatrix/denserepr/from_.py
@@ -9,7 +9,7 @@ from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
from polymatrix.statemonad.init import init_state_monad
from polymatrix.statemonad.mixins import StateMonadMixin
-from polymatrix.expression.utils.monomialtoindex import monomial_to_index
+from polymatrix.denserepr.utils.monomialtoindex import variable_indices_to_column_index
def from_polymatrix(
@@ -97,9 +97,9 @@ def from_polymatrix(
for _ in range(count):
yield index
- new_monomial = tuple(gen_new_monomial())
+ new_variable_indices = tuple(gen_new_monomial())
- cols = monomial_to_index(n_param, new_monomial)
+ cols = variable_indices_to_column_index(n_param, new_variable_indices)
col_value = value / len(cols)
@@ -134,7 +134,7 @@ def from_polymatrix(
for monomial, value in monomial_terms.items():
new_monomial = tuple(variable_index_map[var] for var, count in monomial for _ in range(count))
- cols = monomial_to_index(n_param, new_monomial)
+ cols = variable_indices_to_column_index(n_param, new_monomial)
col_value = value / len(cols)
diff --git a/polymatrix/denserepr/utils/monomialtoindex.py b/polymatrix/denserepr/utils/monomialtoindex.py
new file mode 100644
index 0000000..c13adfd
--- /dev/null
+++ b/polymatrix/denserepr/utils/monomialtoindex.py
@@ -0,0 +1,11 @@
+import itertools
+
+
+def variable_indices_to_column_index(
+ n_var: int,
+ variable_indices: tuple[int, ...],
+) -> int:
+
+ variable_indices_perm = itertools.permutations(variable_indices)
+
+ return set(sum(idx*(n_var**level) for level, idx in enumerate(monomial)) for monomial in variable_indices_perm)
diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py
index da669f4..9da69af 100644
--- a/polymatrix/expression/mixins/combinationsexprmixin.py
+++ b/polymatrix/expression/mixins/combinationsexprmixin.py
@@ -1,13 +1,12 @@
import abc
import itertools
-from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
+from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.multiplymonomials import multiply_monomials
class CombinationsExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index 1728a2d..38ea9cc 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -6,7 +6,7 @@ from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials
+from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial
from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.utils.tooperatorexception import to_operator_exception
@@ -60,23 +60,23 @@ class DerivativeExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
- underlying_terms = underlying.get_poly(row, 0)
- if underlying_terms is None:
+ polynomial = underlying.get_poly(row, 0)
+ if polynomial is None:
continue
# derivate each variable and map result to the corresponding column
for col, diff_wrt_variable in enumerate(diff_wrt_variables):
- state, derivation_terms = get_derivative_monomials(
- monomial_terms=underlying_terms,
+ state, derivation = differentiate_polynomial(
+ polynomial=polynomial,
diff_wrt_variable=diff_wrt_variable,
state=state,
considered_variables=set(diff_wrt_variables),
introduce_derivatives=self.introduce_derivatives,
)
- if 0 < len(derivation_terms):
- terms[row, col] = derivation_terms
+ if 0 < len(derivation):
+ terms[row, col] = derivation
poly_matrix = init_poly_matrix(
terms=terms,
diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py
index 29d1a0a..10db5b4 100644
--- a/polymatrix/expression/mixins/divergenceexprmixin.py
+++ b/polymatrix/expression/mixins/divergenceexprmixin.py
@@ -7,7 +7,7 @@ from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials
+from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial
from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
@@ -38,19 +38,20 @@ class DivergenceExprMixin(ExpressionBaseMixin):
for row, variable in enumerate(variables):
- underlying_terms = underlying.get_poly(row, 0)
- if underlying_terms is None:
+ polynomial = underlying.get_poly(row, 0)
+
+ if polynomial is None:
continue
- state, derivation_terms = get_derivative_monomials(
- monomial_terms=underlying_terms,
+ state, derivation = differentiate_polynomial(
+ polynomial=polynomial,
diff_wrt_variable=variable,
state=state,
considered_variables=set(),
introduce_derivatives=False,
)
- for monomial, value in derivation_terms.items():
+ for monomial, value in derivation.items():
monomial_terms[monomial] += value
poly_matrix = init_poly_matrix(
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index e6e64b1..17be163 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -9,7 +9,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices
+from polymatrix.polymatrix.utils.mergemonomialindices import merge_monomial_indices
class ElemMultExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
index 516a7a4..ed9e79f 100644
--- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
@@ -7,7 +7,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins import ExpressionStateMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
-from polymatrix.expression.utils.sortmonomials import sort_monomials
+from polymatrix.polymatrix.utils.sortmonomials import sort_monomials
class LinearMonomialsExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py
index 3343a65..b7ae5ce 100644
--- a/polymatrix/expression/mixins/matrixmultexprmixin.py
+++ b/polymatrix/expression/mixins/matrixmultexprmixin.py
@@ -6,7 +6,7 @@ from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
+from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
from polymatrix.utils.tooperatorexception import to_operator_exception
diff --git a/polymatrix/expression/mixins/productexprmixin.py b/polymatrix/expression/mixins/productexprmixin.py
index 45be74b..0168e27 100644
--- a/polymatrix/expression/mixins/productexprmixin.py
+++ b/polymatrix/expression/mixins/productexprmixin.py
@@ -1,14 +1,12 @@
import abc
import itertools
-from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
+from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
from polymatrix.utils.getstacklines import FrameSummary
-from polymatrix.utils.tooperatorexception import to_operator_exception
from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.multiplymonomials import multiply_monomials
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
index 19c1610..e739364 100644
--- a/polymatrix/expression/mixins/quadraticinexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticinexprmixin.py
@@ -8,7 +8,7 @@ from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.utils.getmonomialindices import get_monomial_indices
from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
-from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices
+from polymatrix.polymatrix.utils.splitmonomialindices import split_monomial_indices
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.utils.tooperatorexception import to_operator_exception
diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
index ff16275..99befc8 100644
--- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
@@ -7,7 +7,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins import ExpressionStateMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
-from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices
+from polymatrix.polymatrix.utils.splitmonomialindices import split_monomial_indices
class QuadraticMonomialsExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py
index e1c0c3d..b75a2de 100644
--- a/polymatrix/expression/mixins/substituteexprmixin.py
+++ b/polymatrix/expression/mixins/substituteexprmixin.py
@@ -9,7 +9,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
-from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
+from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
class SubstituteExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
index 37add0a..9275332 100644
--- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
@@ -7,8 +7,8 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins import ExpressionStateMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expression.utils.getmonomialindices import get_monomial_indices
-from polymatrix.expression.utils.sortmonomials import sort_monomials
-from polymatrix.expression.utils.subtractmonomialindices import SubtractError, subtract_monomial_indices
+from polymatrix.polymatrix.utils.sortmonomials import sort_monomials
+from polymatrix.polymatrix.utils.subtractmonomialindices import SubtractError, subtract_monomial_indices
class SubtractMonomialsExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/utils/getderivativemonomials.py b/polymatrix/expression/utils/getderivativemonomials.py
index a8b2e16..83083aa 100644
--- a/polymatrix/expression/utils/getderivativemonomials.py
+++ b/polymatrix/expression/utils/getderivativemonomials.py
@@ -3,10 +3,11 @@ import dataclasses
import itertools
from polymatrix.expressionstate.abc import ExpressionState
+from polymatrix.polymatrix.typing import PolynomialData
-def get_derivative_monomials(
- monomial_terms,
+def differentiate_polynomial(
+ polynomial: PolynomialData,
diff_wrt_variable: int,
state: ExpressionState,
considered_variables: set,
@@ -21,7 +22,7 @@ def get_derivative_monomials(
if introduce_derivatives:
def gen_new_variables():
- for monomial in monomial_terms.keys():
+ for monomial in polynomial.keys():
for var in monomial:
if var is not diff_wrt_variable and var not in considered_variables:
yield var
@@ -42,8 +43,8 @@ def get_derivative_monomials(
state = state.register(key=key, n_param=1)
# for each new variable we expect an auxillary equation
- state, auxillary_derivation_terms = get_derivative_monomials(
- monomial_terms=state.auxillary_equations[new_variable],
+ state, auxillary_derivation_terms = differentiate_polynomial(
+ polynomial=state.auxillary_equations[new_variable],
diff_wrt_variable=diff_wrt_variable,
state=state,
considered_variables=new_considered_variables,
@@ -75,7 +76,7 @@ def get_derivative_monomials(
derivation_terms = collections.defaultdict(float)
- for monomial, value in monomial_terms.items():
+ for monomial, value in polynomial.items():
monomial_cnt = dict(monomial)
diff --git a/polymatrix/expression/utils/getmonomialindices.py b/polymatrix/expression/utils/getmonomialindices.py
index 42f9ca7..dcb7932 100644
--- a/polymatrix/expression/utils/getmonomialindices.py
+++ b/polymatrix/expression/utils/getmonomialindices.py
@@ -2,7 +2,10 @@ from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-def get_monomial_indices(state: ExpressionState, monomials: ExpressionBaseMixin) -> tuple[ExpressionState, tuple[int, ...]]:
+def get_monomial_indices(
+ state: ExpressionState,
+ monomials: ExpressionBaseMixin,
+ ) -> tuple[ExpressionState, tuple[int, ...]]:
state, monomials_obj = monomials.apply(state)
diff --git a/polymatrix/expression/utils/gettupleremainder.py b/polymatrix/expression/utils/gettupleremainder.py
deleted file mode 100644
index 2a22ead..0000000
--- a/polymatrix/expression/utils/gettupleremainder.py
+++ /dev/null
@@ -1,9 +0,0 @@
-def get_tuple_remainder(l1, l2):
- # create a copy of l1
- l1_copy = list(l1)
-
- # remove from the copy all elements of l2
- for e in l2:
- l1_copy.remove(e)
-
- return tuple(l1_copy) \ No newline at end of file
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index 4ee7037..d61da57 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -1,9 +1,14 @@
import itertools
+import typing
+from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-def get_variable_indices_from_variable(state, variable) -> tuple[int] | None:
+def get_variable_indices_from_variable(
+ state: ExpressionState,
+ variable: ExpressionBaseMixin | int | typing.Any,
+) -> tuple[int, ...] | None:
if isinstance(variable, ExpressionBaseMixin):
state, variable_polynomial = variable.apply(state)
@@ -40,13 +45,12 @@ def get_variable_indices_from_variable(state, variable) -> tuple[int] | None:
return state, variable_indices
+# not used, remove?
def get_variable_indices(state, variables):
if not isinstance(variables, tuple):
variables = (variables,)
- # assert isinstance(variables, tuple), f'{variables=}'
-
def acc_variable_indices(acc, variable):
state, indices = acc
@@ -61,43 +65,3 @@ def get_variable_indices(state, variables):
)
return state, indices
-
-
- # global_state = [state]
-
- # assert isinstance(variables, tuple), f'{variables=}'
-
- # # if not isinstance(variables, tuple):
- # # variables = (variables,)
-
- # def gen_indices():
- # for variable in variables:
- # if isinstance(variable, ExpressionBaseMixin):
- # global_state[0], variable_polynomial = variable.apply(global_state[0])
-
- # assert variable_polynomial.shape[1] == 1
-
- # for row in range(variable_polynomial.shape[0]):
- # row_terms = variable_polynomial.get_poly(row, 0)
-
- # assert len(row_terms) == 1, f'{row_terms} contains more than one term'
-
- # for monomial in row_terms.keys():
- # assert len(monomial) <= 1, f'{monomial=} contains more than one variable'
-
- # if len(monomial) == 0:
- # continue
-
- # assert monomial[0][1] == 1, f'{monomial[0]=}'
- # yield monomial[0][0]
-
- # elif isinstance(variable, int):
- # yield variable
-
- # # else:
- # elif variable in global_state[0].offset_dict:
- # yield global_state[0].offset_dict[variable][0]
-
- # indices = tuple(gen_indices())
-
- # return global_state[0], indices
diff --git a/polymatrix/expression/utils/mergemonomialindices.py b/polymatrix/expression/utils/mergemonomialindices.py
deleted file mode 100644
index c572852..0000000
--- a/polymatrix/expression/utils/mergemonomialindices.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from polymatrix.expression.utils.sortmonomialindices import sort_monomial_indices
-
-
-def merge_monomial_indices(monomials):
- if len(monomials) == 0:
- return tuple()
-
- elif len(monomials) == 1:
- return monomials[0]
-
- else:
-
- m1_dict = dict(monomials[0])
-
- for other in monomials[1:]:
- for index, count in other:
- if index in m1_dict:
- m1_dict[index] += count
- else:
- m1_dict[index] = count
-
- return sort_monomial_indices(m1_dict.items())
diff --git a/polymatrix/expression/utils/monomialtoindex.py b/polymatrix/expression/utils/monomialtoindex.py
deleted file mode 100644
index bb1e570..0000000
--- a/polymatrix/expression/utils/monomialtoindex.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import itertools
-
-
-def monomial_to_index(n_var, monomial):
- monomial_perm = itertools.permutations(monomial)
-
- return set(sum(idx*(n_var**level) for level, idx in enumerate(monomial)) for monomial in monomial_perm)
diff --git a/polymatrix/expression/utils/multiplymonomials.py b/polymatrix/expression/utils/multiplymonomials.py
deleted file mode 100644
index a4b3273..0000000
--- a/polymatrix/expression/utils/multiplymonomials.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import itertools
-import math
-
-from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices
-
-
-def multiply_monomials(left_monomials, right_monomials):
- def gen_monomials():
- for left_monomial, right_monomial in itertools.product(left_monomials, right_monomials):
-
- monomial = merge_monomial_indices((left_monomial, right_monomial))
-
- yield monomial
-
- return gen_monomials()
diff --git a/polymatrix/expression/utils/sortmonomialindices.py b/polymatrix/expression/utils/sortmonomialindices.py
deleted file mode 100644
index 3d4734e..0000000
--- a/polymatrix/expression/utils/sortmonomialindices.py
+++ /dev/null
@@ -1,5 +0,0 @@
-def sort_monomial_indices(monomial):
- return tuple(sorted(
- monomial,
- key=lambda m: m[0],
- )) \ No newline at end of file
diff --git a/polymatrix/expression/utils/sortmonomials.py b/polymatrix/expression/utils/sortmonomials.py
deleted file mode 100644
index 46e7118..0000000
--- a/polymatrix/expression/utils/sortmonomials.py
+++ /dev/null
@@ -1,5 +0,0 @@
-def sort_monomials(monomials):
- return tuple(sorted(
- monomials,
- key=lambda m: (sum(count for _, count in m), len(m), m),
- )) \ No newline at end of file
diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py
index fe5946e..f957685 100644
--- a/polymatrix/polymatrix/impl.py
+++ b/polymatrix/polymatrix/impl.py
@@ -1,13 +1,14 @@
import dataclassabc
from polymatrix.polymatrix.abc import PolyMatrix
-from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin
+from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin
+from polymatrix.polymatrix.typing import PolynomialData
@dataclassabc.dataclassabc(frozen=True)
class PolyMatrixImpl(PolyMatrix):
- terms: dict
- shape: tuple[int, ...]
+ terms: dict[tuple[int, int], PolynomialData]
+ shape: tuple[int, int]
@dataclassabc.dataclassabc(frozen=True)
diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py
index 7f0b60f..2dc2a63 100644
--- a/polymatrix/polymatrix/init.py
+++ b/polymatrix/polymatrix/init.py
@@ -1,9 +1,10 @@
from polymatrix.polymatrix.impl import PolyMatrixImpl
+from polymatrix.polymatrix.typing import PolynomialData
def init_poly_matrix(
- terms: dict,
- shape: tuple,
+ terms: dict[tuple[int, int], PolynomialData],
+ shape: tuple[int, int],
):
return PolyMatrixImpl(
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py
index 1aa0f19..5046492 100644
--- a/polymatrix/polymatrix/mixins.py
+++ b/polymatrix/polymatrix/mixins.py
@@ -1,6 +1,8 @@
import abc
import typing
+from polymatrix.polymatrix.typing import PolynomialData
+
class PolyMatrixMixin(abc.ABC):
@property
@@ -8,7 +10,7 @@ class PolyMatrixMixin(abc.ABC):
def shape(self) -> tuple[int, int]:
...
- def gen_terms(self) -> typing.Generator[tuple[tuple[int, int], dict[tuple[int, ...], float]], None, None]:
+ def gen_terms(self) -> typing.Generator[tuple[tuple[int, int], PolynomialData], None, None]:
for row in range(self.shape[0]):
for col in range(self.shape[1]):
polynomial = self.get_poly(row, col)
@@ -18,7 +20,7 @@ class PolyMatrixMixin(abc.ABC):
yield (row, col), polynomial
@abc.abstractclassmethod
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float] | None:
+ def get_poly(self, row: int, col: int) -> PolynomialData | None:
...
@@ -28,11 +30,11 @@ class PolyMatrixAsDictMixin(
):
@property
@abc.abstractmethod
- def terms(self) -> dict[tuple[int, int], dict[tuple[int, ...], float]]:
+ def terms(self) -> dict[tuple[int, int], PolynomialData]:
...
# overwrites the abstract method of `PolyMatrixMixin`
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float] | None:
+ def get_poly(self, row: int, col: int) -> PolynomialData | None:
if (row, col) in self.terms:
return self.terms[row, col]
@@ -43,9 +45,9 @@ class BroadcastPolyMatrixMixin(
):
@property
@abc.abstractmethod
- def polynomial(self) -> dict[tuple[int, ...], float]:
+ def polynomial(self) -> PolynomialData:
...
# overwrites the abstract method of `PolyMatrixMixin`
- def get_poly(self, col: int, row: int) -> dict[tuple[int, ...], float] | None:
+ def get_poly(self, col: int, row: int) -> PolynomialData | None:
return self.polynomial
diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py
index f8031e5..0540054 100644
--- a/polymatrix/polymatrix/typing.py
+++ b/polymatrix/polymatrix/typing.py
@@ -1,4 +1,8 @@
+# monomial x1**2 x2
+# with indices {x1: 0, x2: 1}
+# is represented as ((0, 2), (1, 1))
+MonomialData = tuple[tuple[int, int], ...]
-PolynomialData = dict[tuple[int, ...], float]
-PolynomialMatrixData = dict[tuple[int, int], dict[tuple[int, ...], float]] \ No newline at end of file
+PolynomialData = dict[MonomialData, float]
+PolynomialMatrixData = dict[tuple[int, int], dict[MonomialData, float]] \ No newline at end of file
diff --git a/polymatrix/polymatrix/utils/mergemonomialindices.py b/polymatrix/polymatrix/utils/mergemonomialindices.py
new file mode 100644
index 0000000..8285c98
--- /dev/null
+++ b/polymatrix/polymatrix/utils/mergemonomialindices.py
@@ -0,0 +1,38 @@
+from polymatrix.polymatrix.typing import MonomialData
+from polymatrix.polymatrix.utils.sortmonomialindices import sort_monomial_indices
+
+
+def merge_monomial_indices(
+ monomials: tuple[MonomialData, ...],
+) -> MonomialData:
+ """
+ (x1**2 x2, x2**2) -> x1**2 x2**3
+
+ or in terms of indices {x1: 0, x2: 1}:
+
+ (
+ ((0, 2), (1, 1)), # x1**2 x2
+ ((1, 2),) # x2**2
+ ) -> ((0, 2), (1, 3)) # x1**1 x2**3
+ """
+
+ if len(monomials) == 0:
+ return tuple()
+
+ elif len(monomials) == 1:
+ return monomials[0]
+
+ else:
+
+ m1_dict = dict(monomials[0])
+
+ for other in monomials[1:]:
+ for index, count in other:
+ if index in m1_dict:
+ m1_dict[index] += count
+ else:
+ m1_dict[index] = count
+
+ # sort monomials according to their index
+ # ((1, 3), (0, 2)) -> ((0, 2), (1, 3))
+ return sort_monomial_indices(m1_dict.items())
diff --git a/polymatrix/expression/utils/multiplypolynomial.py b/polymatrix/polymatrix/utils/multiplypolynomial.py
index e27a124..17fd154 100644
--- a/polymatrix/expression/utils/multiplypolynomial.py
+++ b/polymatrix/polymatrix/utils/multiplypolynomial.py
@@ -1,10 +1,14 @@
import itertools
import math
-from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices
+from polymatrix.polymatrix.utils.mergemonomialindices import merge_monomial_indices
from polymatrix.polymatrix.typing import PolynomialData
-def multiply_polynomial(left: PolynomialData, right: PolynomialData, result: PolynomialData):
+def multiply_polynomial(
+ left: PolynomialData,
+ right: PolynomialData,
+ result: PolynomialData,
+ ) -> None:
"""
Multiplies two polynomials `left` and `right` and adds the result to the mutable polynomial `result`.
"""
diff --git a/polymatrix/polymatrix/utils/sortmonomialindices.py b/polymatrix/polymatrix/utils/sortmonomialindices.py
new file mode 100644
index 0000000..2a13bc3
--- /dev/null
+++ b/polymatrix/polymatrix/utils/sortmonomialindices.py
@@ -0,0 +1,10 @@
+from polymatrix.polymatrix.typing import MonomialData
+
+
+def sort_monomial_indices(
+ monomial: MonomialData,
+) -> MonomialData:
+ return tuple(sorted(
+ monomial,
+ key=lambda m: m[0],
+ ))
diff --git a/polymatrix/polymatrix/utils/sortmonomials.py b/polymatrix/polymatrix/utils/sortmonomials.py
new file mode 100644
index 0000000..e1c82e7
--- /dev/null
+++ b/polymatrix/polymatrix/utils/sortmonomials.py
@@ -0,0 +1,9 @@
+from polymatrix.polymatrix.typing import MonomialData
+
+def sort_monomials(
+ monomials: tuple[MonomialData],
+) -> tuple[MonomialData]:
+ return tuple(sorted(
+ monomials,
+ key=lambda m: (sum(count for _, count in m), len(m), m),
+ ))
diff --git a/polymatrix/expression/utils/splitmonomialindices.py b/polymatrix/polymatrix/utils/splitmonomialindices.py
index 4ffa6ff..8d5d148 100644
--- a/polymatrix/expression/utils/splitmonomialindices.py
+++ b/polymatrix/polymatrix/utils/splitmonomialindices.py
@@ -1,4 +1,9 @@
-def split_monomial_indices(monomial):
+from polymatrix.polymatrix.typing import MonomialData
+
+
+def split_monomial_indices(
+ monomial: MonomialData,
+) -> tuple[MonomialData, MonomialData]:
left = []
right = []
@@ -21,4 +26,4 @@ def split_monomial_indices(monomial):
if 0 < count_right:
right.append((idx, count - count_left))
- return tuple(left), tuple(right) \ No newline at end of file
+ return tuple(left), tuple(right)
diff --git a/polymatrix/expression/utils/subtractmonomialindices.py b/polymatrix/polymatrix/utils/subtractmonomialindices.py
index 236ba9c..c62bf3e 100644
--- a/polymatrix/expression/utils/subtractmonomialindices.py
+++ b/polymatrix/polymatrix/utils/subtractmonomialindices.py
@@ -1,11 +1,16 @@
-from polymatrix.expression.utils.sortmonomialindices import sort_monomial_indices
+from polymatrix.polymatrix.typing import MonomialData
+
+from polymatrix.polymatrix.utils.sortmonomialindices import sort_monomial_indices
class SubtractError(Exception):
pass
-def subtract_monomial_indices(m1, m2):
+def subtract_monomial_indices(
+ m1: MonomialData,
+ m2: MonomialData,
+) -> MonomialData:
m1_dict = dict(m1)
for index, count in m2: