diff options
-rw-r--r-- | polymatrix/polymatrix/typing.py | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py index 4baff5a..91544d1 100644 --- a/polymatrix/polymatrix/typing.py +++ b/polymatrix/polymatrix/typing.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import NamedTuple, Iterable, cast from collections import UserDict +from itertools import filterfalse # TODO: remove these types, they are here for backward compatiblity MonomialData = tuple[tuple[int, int], ...] @@ -11,6 +12,8 @@ PolynomialMatrixData = dict[tuple[int, int], PolynomialData] class VariableIndex(NamedTuple): """ Index for a variable raised to an integral power. + + `VariableIndex` has a total order with respect to the variable index. """ variable: int # index in ExpressionState object power: int @@ -23,6 +26,10 @@ class VariableIndex(NamedTuple): def constant() -> VariableIndex: return cast(VariableIndex, tuple()) + def is_constant(self) -> bool: + """ Returns true if it is indexing a constant the unit '1' variable. """ + return len(self) > 1 + class MonomialIndex(tuple[VariableIndex]): """ @@ -31,12 +38,56 @@ class MonomialIndex(tuple[VariableIndex]): @staticmethod def empty() -> MonomialIndex: + """ Get an empty monomial index. """ return MonomialIndex(tuple()) @staticmethod def constant() -> MonomialIndex: + """ Get the placeholder for constant terms. """ return MonomialIndex((VariableIndex.constant(),)) + def is_constant(self) -> bool: + """ Returns true if it is indexing a constant monomial. """ + if len(self) > 1: + return False + return self[0].is_constant() + + @staticmethod + def sort(index: MonomialIndex) -> MonomialIndex: + """ Sort the variable indices inside the monomial index. """ + return MonomialIndex(sorted(index)) + + @staticmethod + def product(left: MonomialIndex, right: MonomialIndex) -> MonomialIndex: + """ + Compute the index of the product of two monomials. + + For example if left is the index of :math:`xy` and right is the index + of :math:`y^2` this functions returns the index of :math:`xy^3`. + """ + if left.is_constant(): + return right + + if right.is_constant(): + return left + + # Compute the product of each non-constant term in left with each + # non-constant term in right, by using a dictionary {variable_index: power} + not_const_left = filterfalse(VariableIndex.is_constant, left) + result_dict: dict[int, int] = dict(not_const_left) + for idx, power in filterfalse(VariableIndex.is_constant, right): + if idx not in result_dict: + result_dict[idx] = power + else: + result_dict[idx] += power + + result = MonomialIndex(VariableIndex(k, v) for k, v in result_dict.items()) + return MonomialIndex.sort(result) + + @staticmethod + def differentiate(index: MonomialIndex, wrt: int) -> MonomialIndex | None: + raise NotImplementedError + class PolyDict(UserDict[MonomialIndex, int | float]): """ Polynomial, stored as a dictionary. """ |