From 6a9ca2e0c327281c1359b6152f6afc9165e6bf9b Mon Sep 17 00:00:00 2001
From: Michael Schneeberger <michael.schneeberger@fhnw.ch>
Date: Tue, 27 Feb 2024 07:39:52 +0100
Subject: improve typing with PolynomialData and MonomialData

---
 polymatrix/__init__.py                             | 519 +--------------------
 polymatrix/denserepr/from_.py                      |   8 +-
 polymatrix/denserepr/utils/monomialtoindex.py      |  11 +
 .../expression/mixins/combinationsexprmixin.py     |   3 +-
 .../expression/mixins/derivativeexprmixin.py       |  14 +-
 .../expression/mixins/divergenceexprmixin.py       |  13 +-
 polymatrix/expression/mixins/elemmultexprmixin.py  |   2 +-
 .../expression/mixins/linearmonomialsexprmixin.py  |   2 +-
 .../expression/mixins/matrixmultexprmixin.py       |   2 +-
 polymatrix/expression/mixins/productexprmixin.py   |   4 +-
 .../expression/mixins/quadraticinexprmixin.py      |   2 +-
 .../mixins/quadraticmonomialsexprmixin.py          |   2 +-
 .../expression/mixins/substituteexprmixin.py       |   2 +-
 .../mixins/subtractmonomialsexprmixin.py           |   4 +-
 .../expression/utils/getderivativemonomials.py     |  13 +-
 polymatrix/expression/utils/getmonomialindices.py  |   5 +-
 polymatrix/expression/utils/gettupleremainder.py   |   9 -
 polymatrix/expression/utils/getvariableindices.py  |  50 +-
 .../expression/utils/mergemonomialindices.py       |  22 -
 polymatrix/expression/utils/monomialtoindex.py     |   7 -
 polymatrix/expression/utils/multiplymonomials.py   |  15 -
 polymatrix/expression/utils/multiplypolynomial.py  |  28 --
 polymatrix/expression/utils/sortmonomialindices.py |   5 -
 polymatrix/expression/utils/sortmonomials.py       |   5 -
 .../expression/utils/splitmonomialindices.py       |  24 -
 .../expression/utils/subtractmonomialindices.py    |  23 -
 polymatrix/polymatrix/impl.py                      |   7 +-
 polymatrix/polymatrix/init.py                      |   5 +-
 polymatrix/polymatrix/mixins.py                    |  14 +-
 polymatrix/polymatrix/typing.py                    |   8 +-
 .../polymatrix/utils/mergemonomialindices.py       |  38 ++
 polymatrix/polymatrix/utils/multiplypolynomial.py  |  32 ++
 polymatrix/polymatrix/utils/sortmonomialindices.py |  10 +
 polymatrix/polymatrix/utils/sortmonomials.py       |   9 +
 .../polymatrix/utils/splitmonomialindices.py       |  29 ++
 .../polymatrix/utils/subtractmonomialindices.py    |  28 ++
 36 files changed, 225 insertions(+), 749 deletions(-)
 create mode 100644 polymatrix/denserepr/utils/monomialtoindex.py
 delete mode 100644 polymatrix/expression/utils/gettupleremainder.py
 delete mode 100644 polymatrix/expression/utils/mergemonomialindices.py
 delete mode 100644 polymatrix/expression/utils/monomialtoindex.py
 delete mode 100644 polymatrix/expression/utils/multiplymonomials.py
 delete mode 100644 polymatrix/expression/utils/multiplypolynomial.py
 delete mode 100644 polymatrix/expression/utils/sortmonomialindices.py
 delete mode 100644 polymatrix/expression/utils/sortmonomials.py
 delete mode 100644 polymatrix/expression/utils/splitmonomialindices.py
 delete mode 100644 polymatrix/expression/utils/subtractmonomialindices.py
 create mode 100644 polymatrix/polymatrix/utils/mergemonomialindices.py
 create mode 100644 polymatrix/polymatrix/utils/multiplypolynomial.py
 create mode 100644 polymatrix/polymatrix/utils/sortmonomialindices.py
 create mode 100644 polymatrix/polymatrix/utils/sortmonomials.py
 create mode 100644 polymatrix/polymatrix/utils/splitmonomialindices.py
 create mode 100644 polymatrix/polymatrix/utils/subtractmonomialindices.py

diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 6592238..45968aa 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -2,541 +2,24 @@ from polymatrix.expressionstate.abc import ExpressionState as internal_Expressio
 from polymatrix.expressionstate.init import init_expression_state as internal_init_expression_state
 from polymatrix.expression.expression import Expression as internal_Expression
 from polymatrix.expression.from_ import from_ as internal_from
-# from polymatrix.expression.from_ import from_sympy as internal_from_sympy
 from polymatrix.expression import v_stack as internal_v_stack
 from polymatrix.expression import h_stack as internal_h_stack
 from polymatrix.expression import product as internal_product
-from polymatrix.denserepr.from_ import from_polymatrix
-# from polymatrix.expression.to import shape as internal_shape
 from polymatrix.expression.to import to_constant as internal_to_constant
-# from polymatrix.expression.to import to_degrees as internal_to_degrees
 from polymatrix.expression.to import to_sympy as internal_to_sympy
+from polymatrix.denserepr.from_ import from_polymatrix
 
 Expression = internal_Expression
 ExpressionState = internal_ExpressionState
 
 init_expression_state = internal_init_expression_state
 from_ = internal_from
-# from_sympy = internal_from_sympy
 v_stack = internal_v_stack
 h_stack = internal_h_stack
 product = internal_product
-# to_shape = internal_shape
 to_constant_repr = internal_to_constant
 to_constant = internal_to_constant
-# to_degrees = internal_to_degrees
 to_sympy_repr = internal_to_sympy
 to_sympy = internal_to_sympy
 to_matrix_repr = from_polymatrix
 to_dense = from_polymatrix
-
-# def from_sympy(
-#     data: tuple[tuple[float]],
-# ):
-#     return init_expression(
-#         init_from_sympy_expr(data)
-#     )
-
-# def from_state_monad(
-#     data: StateMonad,
-# ):
-#     return init_expression(
-#         data.flat_map(lambda inner_data: init_from_sympy_expr(inner_data)),
-#     )
-
-# def from_polymatrix(
-#     polymatrix: PolyMatrix,
-# ):
-#     return init_expression(
-#         init_from_terms_expr(polymatrix)
-#     )
-
-# def from_(
-#     data: str | tuple[tuple[float]],
-# ):
-#     if isinstance(data, str):
-#         return from_(1).parametrize(data)
-
-#     return from_sympy(data)
-
-# def v_stack(
-#     expressions: tuple[Expression],
-# ):
-#     return init_expression(
-#         init_v_stack_expr(expressions)
-#     )
-
-# def h_stack(
-#     expressions: tuple[Expression],
-# ):
-#     return init_expression(
-#         init_v_stack_expr(tuple(expr.T for expr in expressions))
-#     ).T
-
-# def block_diag(
-#     expressions: tuple[Expression],
-# ):
-#     return init_expression(
-#         init_block_diag_expr(expressions)
-#     )
-
-# def eye(
-#     variable: tuple[Expression],
-# ):
-#     return init_expression(
-#         init_eye_expr(variable=variable)
-#     )
-
-# def kkt_equality(
-#     equality: Expression,
-#     variables: Expression,
-#     name: str,
-# ):
-#     equality_der = equality.diff(
-#         variables, 
-#         introduce_derivatives=True,
-#     )
-
-#     nu = equality_der[0, :].T.parametrize(name)
-
-#     return nu, equality_der @ nu
-
-# def kkt_inequality(
-#     inequality: Expression,
-#     variables: Expression,
-#     name: str,
-# ):
-    
-#     inequality_der = inequality.diff(
-#         variables, 
-#         introduce_derivatives=True,
-#     )
-
-#     nu = inequality_der[0, :].T.parametrize(f'lambda_{name}')
-#     r_nu = inequality_der[0, :].T.parametrize(f'r_lambda_{name}')
-#     r_ineq = inequality_der[0, :].T.parametrize(f'r_ineq_{name}')
-
-#     prim_expr = inequality - r_ineq * r_ineq
-
-#     sec_expr = nu - r_nu * r_nu
-
-#     compl_expr = nu * inequality
-
-#     cost_expr = inequality_der @ nu
-
-#     return h_stack((nu, r_nu, r_ineq)), (cost_expr, prim_expr, sec_expr, compl_expr)
-
-
-# def rows(
-#     expr: Expression,
-# ) -> StateMonadMixin[ExpressionState, np.ndarray]:
-
-#     def func(state: ExpressionState):
-#         state, underlying = expr.apply(state)
-
-#         def gen_row_terms():
-#             for row in range(underlying.shape[0]):
-
-#                 terms = {}
-
-#                 for col in range(underlying.shape[1]):
-#                     polynomial = underlying.get_poly(row, col)
-#                     if polynomial == None:
-#                         continue
-
-#                     terms[0, col] = polynomial
-
-#                 yield init_expression(underlying=init_from_terms_expr(
-#                     terms=terms,
-#                     shape=(1, underlying.shape[1])
-#                 ))
-
-#         row_terms = tuple(gen_row_terms())
-
-#         return state, row_terms
-
-#     return init_state_monad(func)
-
-
-# @dataclasses.dataclass
-# class MatrixBuffer:
-#     data: dict[int, np.ndarray]
-#     n_row: int
-#     n_param: int
-
-#     def get_max_degree(self):
-#         return max(degree for degree in self.data.keys())
-
-#     def add_buffer(self, index: int):
-#         if index <= 1:
-#             buffer = np.zeros((self.n_row, self.n_param**index), dtype=np.double)
-
-#         else:
-#             buffer = scipy.sparse.dok_array((self.n_row, self.n_param**index), dtype=np.double)
-            
-#         self.data[index] = buffer
-
-#     def add(self, row: int, col: int, index: int, value: float):
-#         if index not in self.data:
-#             self.add_buffer(index)
-
-#         self.data[index][row, col] = value
-
-#     def __getitem__(self, key):
-#         if key not in self.data:
-#             self.add_buffer(key)
-
-#         return self.data[key]
-
-
-# @dataclasses.dataclass
-# class MatrixRepresentations:
-#     data: tuple[MatrixBuffer, ...]
-#     aux_data: typing.Optional[MatrixBuffer]
-#     variable_mapping: tuple[int, ...]
-#     state: ExpressionState
-
-#     def merge_matrix_equations(self):
-#         def gen_matrices(index: int):
-#             for equations in self.data:
-#                 if index < len(equations):
-#                     yield equations[index]
-
-#             if index < len(self.aux_data):
-#                 yield self.aux_data[index]
-
-#         indices = set(key for equations in self.data + (self.aux_data,) for key in equations.keys())
-
-#         def gen_matrices():
-#             for index in indices:
-#                 if index <= 1:
-#                     yield index, np.vstack(tuple(gen_matrices(index)))
-#                 else:
-#                     yield index, scipy.sparse.vstack(tuple(gen_matrices(index)))
-
-#         return dict(gen_matrices())
-
-#     def get_value(self, variable, value):
-#         variable_indices = get_variable_indices_from_variable(self.state, variable)[1]
-
-#         def gen_value_index():
-#             for variable_index in variable_indices:
-#                 try:
-#                     yield self.variable_mapping.index(variable_index)
-#                 except ValueError:
-#                     raise ValueError(f'{variable_index} not found in {self.variable_mapping}')
-
-#         value_index = list(gen_value_index())
-
-#         return value[value_index]
-
-#     def set_value(self, variable, value):
-#         variable_indices = get_variable_indices_from_variable(self.state, variable)[1]
-#         value_index = list(self.variable_mapping.index(variable_index) for variable_index in variable_indices)
-#         vec = np.zeros(len(self.variable_mapping))
-#         vec[value_index] = value
-#         return vec
-
-#     # def get_matrix(self, eq_idx: int):
-#     #     equations = self.data[eq_idx].data
-
-#     def get_func(self, eq_idx: int):
-#         equations = self.data[eq_idx].data
-#         max_idx = max(equations.keys())
-
-#         if 2 <= max_idx:
-#             def func(x: np.ndarray) -> np.ndarray:
-#                 if isinstance(x, tuple) or isinstance(x, list):
-#                     x = np.array(x).reshape(-1, 1)
-
-#                 elif x.shape[0] == 1:
-#                     x = x.reshape(-1, 1)
-
-#                 def acc_x_powers(acc, _):
-#                     next = (acc @ x.T).reshape(-1, 1)
-#                     return next
-
-#                 x_powers = tuple(itertools.accumulate(
-#                     range(max_idx - 1),
-#                     acc_x_powers,
-#                     initial=x,
-#                 ))[1:]
-
-#                 def gen_value():
-#                     for idx, equation in equations.items():
-#                         if idx == 0:
-#                             yield equation
-
-#                         elif idx == 1:
-#                             yield equation @ x
-
-#                         else:
-#                             yield equation @ x_powers[idx-2]
-
-#                 return sum(gen_value())
-
-#         else:
-#             def func(x: np.ndarray) -> np.ndarray:
-#                 if isinstance(x, tuple) or isinstance(x, list):
-#                     x = np.array(x).reshape(-1, 1)
-            
-#                 def gen_value():
-#                     for idx, equation in equations.items():
-#                         if idx == 0:
-#                             yield equation
-
-#                         else:
-#                             yield equation @ x
-
-#                 return sum(gen_value())
-
-#         return func
-
-# def to_matrix_repr(
-#     expressions: Expression | tuple[Expression],
-#     variables: Expression = None,
-#     sorted: bool = None,
-# ) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]:
-
-#     if isinstance(expressions, Expression):
-#         expressions = (expressions,)
-
-#     assert isinstance(variables, ExpressionBaseMixin) or variables is None, f'{variables=}'
-
-#     def func(state: ExpressionState):
-
-#         def acc_underlying_application(acc, v):
-#             state, underlying_list = acc
-
-#             state, underlying = v.apply(state)
-
-#             assert underlying.shape[1] == 1, f'{underlying.shape[1]=} is not 1'
-            
-#             return state, underlying_list + (underlying,)
-
-#         *_, (state, polymatrix_list) = tuple(itertools.accumulate(
-#             expressions, 
-#             acc_underlying_application, 
-#             initial=(state, tuple()),
-#         ))
-
-#         if variables is None:
-#             sorted_variable_index = tuple()
-            
-#         else:
-#             state, variable_index = get_variable_indices_from_variable(state, variables)
-
-#             if sorted:
-#                 tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index)
-
-#                 sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
-                
-#             else:
-#                 sorted_variable_index = variable_index
-
-#             sorted_variable_index_set = set(sorted_variable_index)
-#             if len(sorted_variable_index) != len(sorted_variable_index_set):
-#                 duplicates = tuple(state.get_name_from_offset(var) for var in sorted_variable_index_set if 1 < sorted_variable_index.count(var))
-#                 raise Exception(f'{duplicates=}. Make sure you give a unique name for each variables.')
-        
-#         variable_index_map = {old: new for new, old in enumerate(sorted_variable_index)}
-
-#         n_param = len(sorted_variable_index)
-
-#         def gen_numpy_matrices():
-#             for polymatrix in polymatrix_list:
-
-#                 n_row = polymatrix.shape[0]
-
-#                 buffer = MatrixBuffer(
-#                     data={},
-#                     n_row=n_row,
-#                     n_param=n_param,
-#                 )
-
-#                 for row in range(n_row):
-#                     polymatrix_terms = polymatrix.get_poly(row, 0)
-                    
-#                     if polymatrix_terms is None:
-#                         continue
-
-#                     if len(polymatrix_terms) == 0:
-#                         buffer.add(row, 0, 0, 0)
-                    
-#                     else:
-#                         for monomial, value in polymatrix_terms.items():
-
-#                             def gen_new_monomial():
-#                                 for var, count in monomial:
-#                                     try:
-#                                         index = variable_index_map[var]
-#                                     except KeyError:
-#                                         raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}')
-
-#                                     for _ in range(count):
-#                                         yield index
-
-#                             new_monomial = tuple(gen_new_monomial())
-
-#                             cols = monomial_to_index(n_param, new_monomial)
-
-#                             col_value = value / len(cols)
-
-#                             for col in cols:
-#                                 degree = sum(count for _, count in monomial)
-#                                 buffer.add(row, col, degree, col_value)
-                
-#                 yield buffer
-
-#         underlying_matrices = tuple(gen_numpy_matrices())
-
-#         def gen_auxillary_equations():
-#             for key, monomial_terms in state.auxillary_equations.items():
-#                 if key in sorted_variable_index:
-#                     yield key, monomial_terms
-
-#         auxillary_equations = tuple(gen_auxillary_equations())
-
-#         n_row = len(auxillary_equations)
-
-#         if n_row == 0:
-#             auxillary_matrix_equations = None
-
-#         else:
-#             buffer = MatrixBuffer(
-#                 data={},
-#                 n_row=n_row,
-#                 n_param=n_param,
-#             )
-
-#             for row, (key, monomial_terms) in enumerate(auxillary_equations):
-#                 for monomial, value in monomial_terms.items():
-#                     new_monomial = tuple(variable_index_map[var] for var, count in monomial for _ in range(count))
-
-#                     cols = monomial_to_index(n_param, new_monomial)
-
-#                     col_value = value / len(cols)
-
-#                     for col in cols:
-#                         buffer.add(row, col, sum(count for _, count in monomial), col_value)
-
-#             auxillary_matrix_equations = buffer
-
-#         result = MatrixRepresentations(
-#             data=underlying_matrices,
-#             aux_data=auxillary_matrix_equations,
-#             variable_mapping=sorted_variable_index,
-#             state=state,
-#         )
-
-#         return state, result
-
-#     return init_state_monad(func)
-
-
-# def to_constant_repr(
-#     expr: Expression,
-#     assert_constant: bool = True,
-# ) -> StateMonadMixin[ExpressionState, np.ndarray]:
-
-#     def func(state: ExpressionState):
-#         state, underlying = expr.apply(state)
-
-#         A = np.zeros(underlying.shape, dtype=np.double)
-
-#         for (row, col), polynomial in underlying.gen_terms():
-#             for monomial, value in polynomial.items():
-#                 if len(monomial) == 0:
-#                     A[row, col] = value
-                
-#                 elif assert_constant:
-#                     raise Exception(f'non-constant term {monomial=}')
-
-#         return state, A
-
-#     return init_state_monad(func)
-
-
-# def degrees(
-#     expr: Expression,
-#     variables: Expression,
-# ) -> StateMonadMixin[ExpressionState, np.ndarray]:
-
-#     def func(state: ExpressionState):
-#         state, underlying = expr.apply(state)
-#         state, variable_indices = get_variable_indices_from_variable(state, variables)
-
-#         def gen_rows():
-#             for row in range(underlying.shape[0]):
-#                 def gen_cols():
-#                     for col in range(underlying.shape[1]):
-                        
-#                         def gen_degrees():
-#                             polynomial = underlying.get_poly(row, col)
-
-#                             if polynomial is None:
-#                                 yield 0
-
-#                             else:
-#                                 for monomial, _ in polynomial.items():
-#                                     yield sum(count for var, count in monomial if var in variable_indices)
-
-#                         yield tuple(set(gen_degrees()))
-
-#                 yield tuple(gen_cols())
-
-#         return state, tuple(gen_rows())
-
-#     return init_state_monad(func)
-
-
-# def to_sympy_repr(
-#     expr: Expression,
-# ) -> StateMonadMixin[ExpressionState, sympy.Expr]:
-
-#     def func(state: ExpressionState):
-#         state, underlying = expr.apply(state)
-
-#         A = np.zeros(underlying.shape, dtype=object)
-
-#         for (row, col), polynomial in underlying.gen_terms():
-
-#             sympy_polynomial = 0
-
-#             for monomial, value in polynomial.items():
-#                 sympy_monomial = 1
-
-#                 for offset, count in monomial:
-
-#                     variable = state.get_key_from_offset(offset)
-#                     # def get_variable_from_offset(offset: int):
-#                     #     for variable, (start, end) in state.offset_dict.items():
-#                     #         if start <= offset < end:
-#                                 # assert end - start == 1, f'{start=}, {end=}, {variable=}'
-
-#                     if isinstance(variable, sympy.core.symbol.Symbol):
-#                         variable_name = variable.name
-#                     elif isinstance(variable, (ParametrizeExprMixin, ParametrizeMatrixExprMixin)):
-#                         variable_name = variable.name
-#                     elif isinstance(variable, str):
-#                         variable_name = variable
-#                     else:
-#                         raise Exception(f'{variable=}')
-
-#                     start, end = state.offset_dict[variable]
-
-#                     if end - start == 1:
-#                         sympy_var = sympy.Symbol(variable_name)
-#                     else:
-#                         sympy_var = sympy.Symbol(f'{variable_name}_{offset - start + 1}')
-
-#                     # var = get_variable_from_offset(offset)
-#                     sympy_monomial *= sympy_var**count
-
-#                 sympy_polynomial += value * sympy_monomial
-
-#             A[row, col] = sympy_polynomial
-
-#         return state, A
-
-#     return init_state_monad(func)
diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py
index e44228f..7bafb5d 100644
--- a/polymatrix/denserepr/from_.py
+++ b/polymatrix/denserepr/from_.py
@@ -9,7 +9,7 @@ from polymatrix.expressionstate.abc import ExpressionState
 from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
 from polymatrix.statemonad.init import init_state_monad
 from polymatrix.statemonad.mixins import StateMonadMixin
-from polymatrix.expression.utils.monomialtoindex import monomial_to_index
+from polymatrix.denserepr.utils.monomialtoindex import variable_indices_to_column_index
 
 
 def from_polymatrix(
@@ -97,9 +97,9 @@ def from_polymatrix(
                                     for _ in range(count):
                                         yield index
 
-                            new_monomial = tuple(gen_new_monomial())
+                            new_variable_indices = tuple(gen_new_monomial())
 
-                            cols = monomial_to_index(n_param, new_monomial)
+                            cols = variable_indices_to_column_index(n_param, new_variable_indices)
 
                             col_value = value / len(cols)
 
@@ -134,7 +134,7 @@ def from_polymatrix(
                 for monomial, value in monomial_terms.items():
                     new_monomial = tuple(variable_index_map[var] for var, count in monomial for _ in range(count))
 
-                    cols = monomial_to_index(n_param, new_monomial)
+                    cols = variable_indices_to_column_index(n_param, new_monomial)
 
                     col_value = value / len(cols)
 
diff --git a/polymatrix/denserepr/utils/monomialtoindex.py b/polymatrix/denserepr/utils/monomialtoindex.py
new file mode 100644
index 0000000..c13adfd
--- /dev/null
+++ b/polymatrix/denserepr/utils/monomialtoindex.py
@@ -0,0 +1,11 @@
+import itertools
+
+
+def variable_indices_to_column_index(
+        n_var: int, 
+        variable_indices: tuple[int, ...],
+) -> int:
+    
+    variable_indices_perm = itertools.permutations(variable_indices)
+    
+    return set(sum(idx*(n_var**level) for level, idx in enumerate(monomial)) for monomial in variable_indices_perm)
diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py
index da669f4..9da69af 100644
--- a/polymatrix/expression/mixins/combinationsexprmixin.py
+++ b/polymatrix/expression/mixins/combinationsexprmixin.py
@@ -1,13 +1,12 @@
 
 import abc
 import itertools
-from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
+from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
 
 from polymatrix.polymatrix.init import init_poly_matrix
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.polymatrix.abc import PolyMatrix
 from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.multiplymonomials import multiply_monomials
 
 
 class CombinationsExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index 1728a2d..38ea9cc 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -6,7 +6,7 @@ from polymatrix.polymatrix.init import init_poly_matrix
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.polymatrix.abc import PolyMatrix
 from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials
+from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial
 from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
 from polymatrix.utils.getstacklines import FrameSummary
 from polymatrix.utils.tooperatorexception import to_operator_exception
@@ -60,23 +60,23 @@ class DerivativeExprMixin(ExpressionBaseMixin):
 
         for row in range(underlying.shape[0]):
 
-            underlying_terms = underlying.get_poly(row, 0)
-            if underlying_terms is None:
+            polynomial = underlying.get_poly(row, 0)
+            if polynomial is None:
                 continue
 
             # derivate each variable and map result to the corresponding column
             for col, diff_wrt_variable in enumerate(diff_wrt_variables):
 
-                state, derivation_terms = get_derivative_monomials(
-                    monomial_terms=underlying_terms,
+                state, derivation = differentiate_polynomial(
+                    polynomial=polynomial,
                     diff_wrt_variable=diff_wrt_variable,
                     state=state,
                     considered_variables=set(diff_wrt_variables),
                     introduce_derivatives=self.introduce_derivatives,
                 )
 
-                if 0 < len(derivation_terms):
-                    terms[row, col] = derivation_terms
+                if 0 < len(derivation):
+                    terms[row, col] = derivation
 
         poly_matrix = init_poly_matrix(
             terms=terms,
diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py
index 29d1a0a..10db5b4 100644
--- a/polymatrix/expression/mixins/divergenceexprmixin.py
+++ b/polymatrix/expression/mixins/divergenceexprmixin.py
@@ -7,7 +7,7 @@ from polymatrix.polymatrix.init import init_poly_matrix
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.polymatrix.abc import PolyMatrix
 from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials
+from polymatrix.expression.utils.getderivativemonomials import differentiate_polynomial
 from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
 
 
@@ -38,19 +38,20 @@ class DivergenceExprMixin(ExpressionBaseMixin):
 
         for row, variable in enumerate(variables):
 
-            underlying_terms = underlying.get_poly(row, 0)
-            if underlying_terms is None:
+            polynomial = underlying.get_poly(row, 0)
+
+            if polynomial is None:
                 continue
 
-            state, derivation_terms = get_derivative_monomials(
-                monomial_terms=underlying_terms,
+            state, derivation = differentiate_polynomial(
+                polynomial=polynomial,
                 diff_wrt_variable=variable,
                 state=state,
                 considered_variables=set(),
                 introduce_derivatives=False,
             )
 
-            for monomial, value in derivation_terms.items():
+            for monomial, value in derivation.items():
                 monomial_terms[monomial] += value
 
         poly_matrix = init_poly_matrix(
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index e6e64b1..17be163 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -9,7 +9,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.polymatrix.mixins import PolyMatrixMixin
 from polymatrix.polymatrix.abc import PolyMatrix
 from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices
+from polymatrix.polymatrix.utils.mergemonomialindices import merge_monomial_indices
 
 
 class ElemMultExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
index 516a7a4..ed9e79f 100644
--- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
@@ -7,7 +7,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.expressionstate.mixins import ExpressionStateMixin
 from polymatrix.polymatrix.mixins import PolyMatrixMixin
 from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
-from polymatrix.expression.utils.sortmonomials import sort_monomials
+from polymatrix.polymatrix.utils.sortmonomials import sort_monomials
 
 
 class LinearMonomialsExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py
index 3343a65..b7ae5ce 100644
--- a/polymatrix/expression/mixins/matrixmultexprmixin.py
+++ b/polymatrix/expression/mixins/matrixmultexprmixin.py
@@ -6,7 +6,7 @@ from polymatrix.polymatrix.init import init_poly_matrix
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.polymatrix.abc import PolyMatrix
 from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
+from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
 from polymatrix.utils.tooperatorexception import to_operator_exception
 
 
diff --git a/polymatrix/expression/mixins/productexprmixin.py b/polymatrix/expression/mixins/productexprmixin.py
index 45be74b..0168e27 100644
--- a/polymatrix/expression/mixins/productexprmixin.py
+++ b/polymatrix/expression/mixins/productexprmixin.py
@@ -1,14 +1,12 @@
 
 import abc
 import itertools
-from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
+from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
 
 from polymatrix.utils.getstacklines import FrameSummary
-from polymatrix.utils.tooperatorexception import to_operator_exception
 from polymatrix.polymatrix.init import init_poly_matrix
 from polymatrix.polymatrix.abc import PolyMatrix
 from polymatrix.expressionstate.abc import ExpressionState
-from polymatrix.expression.utils.multiplymonomials import multiply_monomials
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 
 
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
index 19c1610..e739364 100644
--- a/polymatrix/expression/mixins/quadraticinexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticinexprmixin.py
@@ -8,7 +8,7 @@ from polymatrix.polymatrix.abc import PolyMatrix
 from polymatrix.expressionstate.abc import ExpressionState
 from polymatrix.expression.utils.getmonomialindices import get_monomial_indices
 from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
-from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices
+from polymatrix.polymatrix.utils.splitmonomialindices import split_monomial_indices
 from polymatrix.utils.getstacklines import FrameSummary
 from polymatrix.utils.tooperatorexception import to_operator_exception
 
diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
index ff16275..99befc8 100644
--- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
@@ -7,7 +7,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.expressionstate.mixins import ExpressionStateMixin
 from polymatrix.polymatrix.mixins import PolyMatrixMixin
 from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
-from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices
+from polymatrix.polymatrix.utils.splitmonomialindices import split_monomial_indices
 
 
 class QuadraticMonomialsExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py
index e1c0c3d..b75a2de 100644
--- a/polymatrix/expression/mixins/substituteexprmixin.py
+++ b/polymatrix/expression/mixins/substituteexprmixin.py
@@ -9,7 +9,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.polymatrix.abc import PolyMatrix
 from polymatrix.expressionstate.abc import ExpressionState
 from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
-from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
+from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
 
 
 class SubstituteExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
index 37add0a..9275332 100644
--- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
@@ -7,8 +7,8 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 from polymatrix.expressionstate.mixins import ExpressionStateMixin
 from polymatrix.polymatrix.mixins import PolyMatrixMixin
 from polymatrix.expression.utils.getmonomialindices import get_monomial_indices
-from polymatrix.expression.utils.sortmonomials import sort_monomials
-from polymatrix.expression.utils.subtractmonomialindices import SubtractError, subtract_monomial_indices
+from polymatrix.polymatrix.utils.sortmonomials import sort_monomials
+from polymatrix.polymatrix.utils.subtractmonomialindices import SubtractError, subtract_monomial_indices
 
 
 class SubtractMonomialsExprMixin(ExpressionBaseMixin):
diff --git a/polymatrix/expression/utils/getderivativemonomials.py b/polymatrix/expression/utils/getderivativemonomials.py
index a8b2e16..83083aa 100644
--- a/polymatrix/expression/utils/getderivativemonomials.py
+++ b/polymatrix/expression/utils/getderivativemonomials.py
@@ -3,10 +3,11 @@ import dataclasses
 import itertools
 
 from polymatrix.expressionstate.abc import ExpressionState
+from polymatrix.polymatrix.typing import PolynomialData
 
 
-def get_derivative_monomials(
-    monomial_terms, 
+def differentiate_polynomial(
+    polynomial: PolynomialData, 
     diff_wrt_variable: int, 
     state: ExpressionState,
     considered_variables: set,
@@ -21,7 +22,7 @@ def get_derivative_monomials(
     if introduce_derivatives:
         
         def gen_new_variables():
-            for monomial in monomial_terms.keys():
+            for monomial in polynomial.keys():
                 for var in monomial:
                     if var is not diff_wrt_variable and var not in considered_variables:
                         yield var
@@ -42,8 +43,8 @@ def get_derivative_monomials(
             state = state.register(key=key, n_param=1)
 
             # for each new variable we expect an auxillary equation
-            state, auxillary_derivation_terms = get_derivative_monomials(
-                monomial_terms=state.auxillary_equations[new_variable],
+            state, auxillary_derivation_terms = differentiate_polynomial(
+                polynomial=state.auxillary_equations[new_variable],
                 diff_wrt_variable=diff_wrt_variable,
                 state=state,
                 considered_variables=new_considered_variables,
@@ -75,7 +76,7 @@ def get_derivative_monomials(
 
     derivation_terms = collections.defaultdict(float)
 
-    for monomial, value in monomial_terms.items():
+    for monomial, value in polynomial.items():
 
         monomial_cnt = dict(monomial)
 
diff --git a/polymatrix/expression/utils/getmonomialindices.py b/polymatrix/expression/utils/getmonomialindices.py
index 42f9ca7..dcb7932 100644
--- a/polymatrix/expression/utils/getmonomialindices.py
+++ b/polymatrix/expression/utils/getmonomialindices.py
@@ -2,7 +2,10 @@ from polymatrix.expressionstate.abc import ExpressionState
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 
 
-def get_monomial_indices(state: ExpressionState, monomials: ExpressionBaseMixin) -> tuple[ExpressionState, tuple[int, ...]]:
+def get_monomial_indices(
+        state: ExpressionState, 
+        monomials: ExpressionBaseMixin,
+    ) -> tuple[ExpressionState, tuple[int, ...]]:
 
     state, monomials_obj = monomials.apply(state)
 
diff --git a/polymatrix/expression/utils/gettupleremainder.py b/polymatrix/expression/utils/gettupleremainder.py
deleted file mode 100644
index 2a22ead..0000000
--- a/polymatrix/expression/utils/gettupleremainder.py
+++ /dev/null
@@ -1,9 +0,0 @@
-def get_tuple_remainder(l1, l2):
-    # create a copy of l1
-    l1_copy = list(l1)
-    
-    # remove from the copy all elements of l2
-    for e in l2:
-        l1_copy.remove(e)
-        
-    return tuple(l1_copy)
\ No newline at end of file
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index 4ee7037..d61da57 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -1,9 +1,14 @@
 import itertools
+import typing
 
+from polymatrix.expressionstate.abc import ExpressionState
 from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
 
 
-def get_variable_indices_from_variable(state, variable) -> tuple[int] | None:
+def get_variable_indices_from_variable(
+        state: ExpressionState, 
+        variable: ExpressionBaseMixin | int | typing.Any,
+) -> tuple[int, ...] | None:
     
     if isinstance(variable, ExpressionBaseMixin):
         state, variable_polynomial = variable.apply(state)
@@ -40,13 +45,12 @@ def get_variable_indices_from_variable(state, variable) -> tuple[int] | None:
     return state, variable_indices
 
 
+# not used, remove?
 def get_variable_indices(state, variables):
 
     if not isinstance(variables, tuple):
         variables = (variables,)
 
-    # assert isinstance(variables, tuple), f'{variables=}'
-
     def acc_variable_indices(acc, variable):
         state, indices = acc
 
@@ -61,43 +65,3 @@ def get_variable_indices(state, variables):
     )
 
     return state, indices
-
-
-    # global_state = [state]
-
-    # assert isinstance(variables, tuple), f'{variables=}'
-
-    # # if not isinstance(variables, tuple):
-    # #     variables = (variables,)
-
-    # def gen_indices():
-    #     for variable in variables:
-    #         if isinstance(variable, ExpressionBaseMixin):
-    #             global_state[0], variable_polynomial = variable.apply(global_state[0])
-
-    #             assert variable_polynomial.shape[1] == 1
-
-    #             for row in range(variable_polynomial.shape[0]):
-    #                 row_terms = variable_polynomial.get_poly(row, 0)
-
-    #                 assert len(row_terms) == 1, f'{row_terms} contains more than one term'
-                
-    #                 for monomial in row_terms.keys():
-    #                     assert len(monomial) <= 1, f'{monomial=} contains more than one variable'
-
-    #                     if len(monomial) == 0:
-    #                         continue
-                    
-    #                     assert monomial[0][1] == 1, f'{monomial[0]=}'
-    #                     yield monomial[0][0]
-
-    #         elif isinstance(variable, int):
-    #             yield variable
-
-    #         # else:
-    #         elif variable in global_state[0].offset_dict:
-    #             yield global_state[0].offset_dict[variable][0]
-
-    # indices = tuple(gen_indices())
-
-    # return global_state[0], indices
diff --git a/polymatrix/expression/utils/mergemonomialindices.py b/polymatrix/expression/utils/mergemonomialindices.py
deleted file mode 100644
index c572852..0000000
--- a/polymatrix/expression/utils/mergemonomialindices.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from polymatrix.expression.utils.sortmonomialindices import sort_monomial_indices
-
-
-def merge_monomial_indices(monomials):
-    if len(monomials) == 0:
-        return tuple()
-
-    elif len(monomials) == 1:
-        return monomials[0]
-
-    else:
-
-        m1_dict = dict(monomials[0])
-
-        for other in monomials[1:]:
-            for index, count in other:
-                if index in m1_dict:
-                    m1_dict[index] += count
-                else:
-                    m1_dict[index] = count
-
-        return sort_monomial_indices(m1_dict.items())
diff --git a/polymatrix/expression/utils/monomialtoindex.py b/polymatrix/expression/utils/monomialtoindex.py
deleted file mode 100644
index bb1e570..0000000
--- a/polymatrix/expression/utils/monomialtoindex.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import itertools
-
-
-def monomial_to_index(n_var, monomial):
-    monomial_perm = itertools.permutations(monomial)
-    
-    return set(sum(idx*(n_var**level) for level, idx in enumerate(monomial)) for monomial in monomial_perm)
diff --git a/polymatrix/expression/utils/multiplymonomials.py b/polymatrix/expression/utils/multiplymonomials.py
deleted file mode 100644
index a4b3273..0000000
--- a/polymatrix/expression/utils/multiplymonomials.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import itertools
-import math
-
-from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices
-
-
-def multiply_monomials(left_monomials, right_monomials):
-    def gen_monomials():
-        for left_monomial, right_monomial in itertools.product(left_monomials, right_monomials):
-            
-            monomial = merge_monomial_indices((left_monomial, right_monomial))
-
-            yield monomial
-
-    return gen_monomials()
diff --git a/polymatrix/expression/utils/multiplypolynomial.py b/polymatrix/expression/utils/multiplypolynomial.py
deleted file mode 100644
index e27a124..0000000
--- a/polymatrix/expression/utils/multiplypolynomial.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import itertools
-import math
-
-from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices
-from polymatrix.polymatrix.typing import PolynomialData
-
-def multiply_polynomial(left: PolynomialData, right: PolynomialData, result: PolynomialData):
-    """
-    Multiplies two polynomials `left` and `right` and adds the result to the mutable polynomial `result`.
-    """
-    
-    for (left_monomial, left_value), (right_monomial, right_value) \
-            in itertools.product(left.items(), right.items()):
-
-        value = left_value * right_value
-
-        if math.isclose(value, 0, abs_tol=1e-12):
-            continue
-        
-        monomial = merge_monomial_indices((left_monomial, right_monomial))
-
-        if monomial not in result:
-            result[monomial] = 0
-
-        result[monomial] += value
-
-        if math.isclose(result[monomial], 0, abs_tol=1e-12):
-            del result[monomial]
diff --git a/polymatrix/expression/utils/sortmonomialindices.py b/polymatrix/expression/utils/sortmonomialindices.py
deleted file mode 100644
index 3d4734e..0000000
--- a/polymatrix/expression/utils/sortmonomialindices.py
+++ /dev/null
@@ -1,5 +0,0 @@
-def sort_monomial_indices(monomial):
-    return tuple(sorted(
-        monomial, 
-        key=lambda m: m[0],
-    ))
\ No newline at end of file
diff --git a/polymatrix/expression/utils/sortmonomials.py b/polymatrix/expression/utils/sortmonomials.py
deleted file mode 100644
index 46e7118..0000000
--- a/polymatrix/expression/utils/sortmonomials.py
+++ /dev/null
@@ -1,5 +0,0 @@
-def sort_monomials(monomials):
-    return tuple(sorted(
-        monomials, 
-        key=lambda m: (sum(count for _, count in m), len(m), m),
-    ))
\ No newline at end of file
diff --git a/polymatrix/expression/utils/splitmonomialindices.py b/polymatrix/expression/utils/splitmonomialindices.py
deleted file mode 100644
index 4ffa6ff..0000000
--- a/polymatrix/expression/utils/splitmonomialindices.py
+++ /dev/null
@@ -1,24 +0,0 @@
-def split_monomial_indices(monomial):
-    left = []
-    right = []
-
-    is_left = True
-
-    for idx, count in monomial:  
-        count_left = count // 2
-
-        if count % 2:
-            if is_left:
-                count_left = count_left + 1
-                
-            is_left = not is_left
-            
-        count_right = count - count_left
-
-        if 0 < count_left:
-            left.append((idx, count_left))
-
-        if 0 < count_right:
-            right.append((idx, count - count_left))
-    
-    return tuple(left), tuple(right)
\ No newline at end of file
diff --git a/polymatrix/expression/utils/subtractmonomialindices.py b/polymatrix/expression/utils/subtractmonomialindices.py
deleted file mode 100644
index 236ba9c..0000000
--- a/polymatrix/expression/utils/subtractmonomialindices.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from polymatrix.expression.utils.sortmonomialindices import sort_monomial_indices
-
-
-class SubtractError(Exception):
-    pass
-
-
-def subtract_monomial_indices(m1, m2):
-    m1_dict = dict(m1)
-
-    for index, count in m2:
-        if index not in m1_dict:
-            raise SubtractError()
-
-        m1_dict[index] -= count
-
-        if m1_dict[index] == 0:
-            del m1_dict[index]
-
-        elif m1_dict[index] < 0:
-            raise SubtractError()
-
-    return sort_monomial_indices(m1_dict.items())
diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py
index fe5946e..f957685 100644
--- a/polymatrix/polymatrix/impl.py
+++ b/polymatrix/polymatrix/impl.py
@@ -1,13 +1,14 @@
 import dataclassabc
 
 from polymatrix.polymatrix.abc import PolyMatrix
-from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin 
+from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin
+from polymatrix.polymatrix.typing import PolynomialData 
 
 
 @dataclassabc.dataclassabc(frozen=True)
 class PolyMatrixImpl(PolyMatrix):
-	terms: dict
-	shape: tuple[int, ...]
+	terms: dict[tuple[int, int], PolynomialData]
+	shape: tuple[int, int]
 
 
 @dataclassabc.dataclassabc(frozen=True)
diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py
index 7f0b60f..2dc2a63 100644
--- a/polymatrix/polymatrix/init.py
+++ b/polymatrix/polymatrix/init.py
@@ -1,9 +1,10 @@
 from polymatrix.polymatrix.impl import PolyMatrixImpl
+from polymatrix.polymatrix.typing import PolynomialData
 
 
 def init_poly_matrix(
-		terms: dict,
-		shape: tuple,
+		terms: dict[tuple[int, int], PolynomialData],
+		shape: tuple[int, int],
 ):
 
 	return PolyMatrixImpl(
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py
index 1aa0f19..5046492 100644
--- a/polymatrix/polymatrix/mixins.py
+++ b/polymatrix/polymatrix/mixins.py
@@ -1,6 +1,8 @@
 import abc
 import typing
 
+from polymatrix.polymatrix.typing import PolynomialData
+
 
 class PolyMatrixMixin(abc.ABC):
     @property
@@ -8,7 +10,7 @@ class PolyMatrixMixin(abc.ABC):
     def shape(self) -> tuple[int, int]:
         ...
 
-    def gen_terms(self) -> typing.Generator[tuple[tuple[int, int], dict[tuple[int, ...], float]], None, None]:
+    def gen_terms(self) -> typing.Generator[tuple[tuple[int, int], PolynomialData], None, None]:
         for row in range(self.shape[0]):
             for col in range(self.shape[1]):
                 polynomial = self.get_poly(row, col)
@@ -18,7 +20,7 @@ class PolyMatrixMixin(abc.ABC):
                 yield (row, col), polynomial
 
     @abc.abstractclassmethod
-    def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float] | None:
+    def get_poly(self, row: int, col: int) -> PolynomialData | None:
         ...
 
 
@@ -28,11 +30,11 @@ class PolyMatrixAsDictMixin(
 ):
     @property
     @abc.abstractmethod
-    def terms(self) -> dict[tuple[int, int], dict[tuple[int, ...], float]]:
+    def terms(self) -> dict[tuple[int, int], PolynomialData]:
         ...
 
     # overwrites the abstract method of `PolyMatrixMixin`
-    def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float] | None:
+    def get_poly(self, row: int, col: int) -> PolynomialData | None:
         if (row, col) in self.terms:
             return self.terms[row, col]
 
@@ -43,9 +45,9 @@ class BroadcastPolyMatrixMixin(
 ):
     @property
     @abc.abstractmethod
-    def polynomial(self) -> dict[tuple[int, ...], float]:
+    def polynomial(self) -> PolynomialData:
         ...
 
     # overwrites the abstract method of `PolyMatrixMixin`
-    def get_poly(self, col: int, row: int) -> dict[tuple[int, ...], float] | None:
+    def get_poly(self, col: int, row: int) -> PolynomialData | None:
         return self.polynomial
diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py
index f8031e5..0540054 100644
--- a/polymatrix/polymatrix/typing.py
+++ b/polymatrix/polymatrix/typing.py
@@ -1,4 +1,8 @@
 
+# monomial x1**2 x2
+# with indices {x1: 0, x2: 1}
+# is represented as ((0, 2), (1, 1))
+MonomialData = tuple[tuple[int, int], ...]
 
-PolynomialData = dict[tuple[int, ...], float]
-PolynomialMatrixData = dict[tuple[int, int], dict[tuple[int, ...], float]]
\ No newline at end of file
+PolynomialData = dict[MonomialData, float]
+PolynomialMatrixData = dict[tuple[int, int], dict[MonomialData, float]]
\ No newline at end of file
diff --git a/polymatrix/polymatrix/utils/mergemonomialindices.py b/polymatrix/polymatrix/utils/mergemonomialindices.py
new file mode 100644
index 0000000..8285c98
--- /dev/null
+++ b/polymatrix/polymatrix/utils/mergemonomialindices.py
@@ -0,0 +1,38 @@
+from polymatrix.polymatrix.typing import MonomialData
+from polymatrix.polymatrix.utils.sortmonomialindices import sort_monomial_indices
+
+
+def merge_monomial_indices(
+    monomials: tuple[MonomialData, ...],
+) -> MonomialData:
+    """
+    (x1**2 x2, x2**2)  ->  x1**2 x2**3
+
+    or in terms of indices {x1: 0, x2: 1}:
+
+        (
+            ((0, 2), (1, 1)),      # x1**2 x2
+            ((1, 2),)              # x2**2
+        )  ->  ((0, 2), (1, 3))    # x1**1 x2**3
+    """
+    
+    if len(monomials) == 0:
+        return tuple()
+
+    elif len(monomials) == 1:
+        return monomials[0]
+
+    else:
+
+        m1_dict = dict(monomials[0])
+
+        for other in monomials[1:]:
+            for index, count in other:
+                if index in m1_dict:
+                    m1_dict[index] += count
+                else:
+                    m1_dict[index] = count
+
+        # sort monomials according to their index
+        # ((1, 3), (0, 2))  ->  ((0, 2), (1, 3))
+        return sort_monomial_indices(m1_dict.items())
diff --git a/polymatrix/polymatrix/utils/multiplypolynomial.py b/polymatrix/polymatrix/utils/multiplypolynomial.py
new file mode 100644
index 0000000..17fd154
--- /dev/null
+++ b/polymatrix/polymatrix/utils/multiplypolynomial.py
@@ -0,0 +1,32 @@
+import itertools
+import math
+
+from polymatrix.polymatrix.utils.mergemonomialindices import merge_monomial_indices
+from polymatrix.polymatrix.typing import PolynomialData
+
+def multiply_polynomial(
+        left: PolynomialData, 
+        right: PolynomialData, 
+        result: PolynomialData,
+    ) -> None:
+    """
+    Multiplies two polynomials `left` and `right` and adds the result to the mutable polynomial `result`.
+    """
+    
+    for (left_monomial, left_value), (right_monomial, right_value) \
+            in itertools.product(left.items(), right.items()):
+
+        value = left_value * right_value
+
+        if math.isclose(value, 0, abs_tol=1e-12):
+            continue
+        
+        monomial = merge_monomial_indices((left_monomial, right_monomial))
+
+        if monomial not in result:
+            result[monomial] = 0
+
+        result[monomial] += value
+
+        if math.isclose(result[monomial], 0, abs_tol=1e-12):
+            del result[monomial]
diff --git a/polymatrix/polymatrix/utils/sortmonomialindices.py b/polymatrix/polymatrix/utils/sortmonomialindices.py
new file mode 100644
index 0000000..2a13bc3
--- /dev/null
+++ b/polymatrix/polymatrix/utils/sortmonomialindices.py
@@ -0,0 +1,10 @@
+from polymatrix.polymatrix.typing import MonomialData
+
+
+def sort_monomial_indices(
+        monomial: MonomialData,
+) -> MonomialData:
+    return tuple(sorted(
+        monomial, 
+        key=lambda m: m[0],
+    ))
diff --git a/polymatrix/polymatrix/utils/sortmonomials.py b/polymatrix/polymatrix/utils/sortmonomials.py
new file mode 100644
index 0000000..e1c82e7
--- /dev/null
+++ b/polymatrix/polymatrix/utils/sortmonomials.py
@@ -0,0 +1,9 @@
+from polymatrix.polymatrix.typing import MonomialData
+
+def sort_monomials(
+        monomials: tuple[MonomialData],
+) -> tuple[MonomialData]:
+    return tuple(sorted(
+        monomials, 
+        key=lambda m: (sum(count for _, count in m), len(m), m),
+    ))
diff --git a/polymatrix/polymatrix/utils/splitmonomialindices.py b/polymatrix/polymatrix/utils/splitmonomialindices.py
new file mode 100644
index 0000000..8d5d148
--- /dev/null
+++ b/polymatrix/polymatrix/utils/splitmonomialindices.py
@@ -0,0 +1,29 @@
+from polymatrix.polymatrix.typing import MonomialData
+
+
+def split_monomial_indices(
+        monomial: MonomialData,
+) -> tuple[MonomialData, MonomialData]:
+    left = []
+    right = []
+
+    is_left = True
+
+    for idx, count in monomial:  
+        count_left = count // 2
+
+        if count % 2:
+            if is_left:
+                count_left = count_left + 1
+                
+            is_left = not is_left
+            
+        count_right = count - count_left
+
+        if 0 < count_left:
+            left.append((idx, count_left))
+
+        if 0 < count_right:
+            right.append((idx, count - count_left))
+    
+    return tuple(left), tuple(right)
diff --git a/polymatrix/polymatrix/utils/subtractmonomialindices.py b/polymatrix/polymatrix/utils/subtractmonomialindices.py
new file mode 100644
index 0000000..c62bf3e
--- /dev/null
+++ b/polymatrix/polymatrix/utils/subtractmonomialindices.py
@@ -0,0 +1,28 @@
+from polymatrix.polymatrix.typing import MonomialData
+
+from polymatrix.polymatrix.utils.sortmonomialindices import sort_monomial_indices
+
+
+class SubtractError(Exception):
+    pass
+
+
+def subtract_monomial_indices(
+        m1: MonomialData, 
+        m2: MonomialData,
+) -> MonomialData:
+    m1_dict = dict(m1)
+
+    for index, count in m2:
+        if index not in m1_dict:
+            raise SubtractError()
+
+        m1_dict[index] -= count
+
+        if m1_dict[index] == 0:
+            del m1_dict[index]
+
+        elif m1_dict[index] < 0:
+            raise SubtractError()
+
+    return sort_monomial_indices(m1_dict.items())
-- 
cgit v1.2.1