summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/polymatrix/typing.py51
1 files changed, 51 insertions, 0 deletions
diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py
index 4baff5a..91544d1 100644
--- a/polymatrix/polymatrix/typing.py
+++ b/polymatrix/polymatrix/typing.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import NamedTuple, Iterable, cast
from collections import UserDict
+from itertools import filterfalse
# TODO: remove these types, they are here for backward compatiblity
MonomialData = tuple[tuple[int, int], ...]
@@ -11,6 +12,8 @@ PolynomialMatrixData = dict[tuple[int, int], PolynomialData]
class VariableIndex(NamedTuple):
"""
Index for a variable raised to an integral power.
+
+ `VariableIndex` has a total order with respect to the variable index.
"""
variable: int # index in ExpressionState object
power: int
@@ -23,6 +26,10 @@ class VariableIndex(NamedTuple):
def constant() -> VariableIndex:
return cast(VariableIndex, tuple())
+ def is_constant(self) -> bool:
+ """ Returns true if it is indexing a constant the unit '1' variable. """
+ return len(self) > 1
+
class MonomialIndex(tuple[VariableIndex]):
"""
@@ -31,12 +38,56 @@ class MonomialIndex(tuple[VariableIndex]):
@staticmethod
def empty() -> MonomialIndex:
+ """ Get an empty monomial index. """
return MonomialIndex(tuple())
@staticmethod
def constant() -> MonomialIndex:
+ """ Get the placeholder for constant terms. """
return MonomialIndex((VariableIndex.constant(),))
+ def is_constant(self) -> bool:
+ """ Returns true if it is indexing a constant monomial. """
+ if len(self) > 1:
+ return False
+ return self[0].is_constant()
+
+ @staticmethod
+ def sort(index: MonomialIndex) -> MonomialIndex:
+ """ Sort the variable indices inside the monomial index. """
+ return MonomialIndex(sorted(index))
+
+ @staticmethod
+ def product(left: MonomialIndex, right: MonomialIndex) -> MonomialIndex:
+ """
+ Compute the index of the product of two monomials.
+
+ For example if left is the index of :math:`xy` and right is the index
+ of :math:`y^2` this functions returns the index of :math:`xy^3`.
+ """
+ if left.is_constant():
+ return right
+
+ if right.is_constant():
+ return left
+
+ # Compute the product of each non-constant term in left with each
+ # non-constant term in right, by using a dictionary {variable_index: power}
+ not_const_left = filterfalse(VariableIndex.is_constant, left)
+ result_dict: dict[int, int] = dict(not_const_left)
+ for idx, power in filterfalse(VariableIndex.is_constant, right):
+ if idx not in result_dict:
+ result_dict[idx] = power
+ else:
+ result_dict[idx] += power
+
+ result = MonomialIndex(VariableIndex(k, v) for k, v in result_dict.items())
+ return MonomialIndex.sort(result)
+
+ @staticmethod
+ def differentiate(index: MonomialIndex, wrt: int) -> MonomialIndex | None:
+ raise NotImplementedError
+
class PolyDict(UserDict[MonomialIndex, int | float]):
""" Polynomial, stored as a dictionary. """