From e2a9b4801dc15adb46346ba10a580ac86f7b39bd Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Sun, 3 Mar 2024 00:06:13 +0100
Subject: Add total ordering to PolyVarIndex and PolyIndex.product to compute
 index of products

---
 mdpoly/types.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 61 insertions(+), 4 deletions(-)

diff --git a/mdpoly/types.py b/mdpoly/types.py
index b580b99..4b75e8f 100644
--- a/mdpoly/types.py
+++ b/mdpoly/types.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
+from .constants import NUMERICS_EPS
 from .errors import InvalidShape
 
+from itertools import filterfalse
 from typing import Self, NamedTuple, Iterable, TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -68,15 +70,28 @@ class PolyVarIndex(NamedTuple):
     """ 
     Tuple to index a power of a variable.
 
-    Concretely this represents things like x^2, y^4
+    Concretely this represents things like x^2, y^4 but not xy
     """
     var_idx: int # Index in State.variables, not variable!
     power: Number
 
+    def __eq__(self, other: Self):
+        if self.var_idx != other.var_idx:
+            return False
+
+        if (self.power - other.power) > NUMERICS_EPS:
+            return False
+        
+        return True
+
+    def __lt__(self, other: Self):
+        return self.var_idx < other.var_idx
+
     @classmethod
     @property
     def constant(cls):
-        """ Special index for constants.
+        """
+        Special index for constants.
 
         Constants do not have an associated variable,
         and the field power is ignored.
@@ -85,6 +100,7 @@ class PolyVarIndex(NamedTuple):
 
     @staticmethod
     def is_constant(index: Self) -> bool:
+        """ Check if is index of a constant term """
         return index.var_idx == -1
 
     @classmethod
@@ -94,7 +110,7 @@ class PolyVarIndex(NamedTuple):
 
 class PolyIndex(tuple[PolyVarIndex]):
     """
-    Tuple to index a coefficient of a polynomial.
+    Tuple to index a coefficient of a monomial in a polynomial.
 
     Example: 
 
@@ -116,8 +132,49 @@ class PolyIndex(tuple[PolyVarIndex]):
         indices = ", ".join(map(repr, self))
         return f"{name}({indices})"
 
+    @classmethod
+    def from_dict(cls, d) -> Self:
+        return cls(PolyVarIndex(k, v) for k, v in d.items())
+
     @classmethod
     @property
-    def constant(cls):
+    def constant(cls) -> Self:
         """ Index of the constant term """
         return cls((PolyVarIndex.constant,))
+
+    @staticmethod
+    def is_constant(index: Self) -> bool:
+        """ Check if is index of a constant term """
+        if len(index) != 1:
+            # FIXME: error message
+            raise IndexError(f"{index}")
+
+        return PolyVarIndex.is_constant(index[0])
+
+    @classmethod
+    def sort(cls, index: Self) -> Self:
+        """ Sort a tuple of indices """
+        return cls(sorted(index))
+
+    @classmethod
+    def product(cls, left: Self, right: Self) -> Self:
+        """
+        Compute index of product
+
+        For example if left is the index of xy and right is the index of y^2
+        this functions returns the index of xy^3
+        """
+        if cls.is_constant(left):
+            return right
+
+        if cls.is_constant(right):
+            return left
+
+        result: dict[int, Number] = dict(filterfalse(PolyVarIndex.is_constant, left))
+        for idx, power in filterfalse(PolyVarIndex.is_constant, right):
+            if idx not in result:
+                result[idx] = power
+            else:
+                result[idx] += power
+
+        return cls.sort(cls.from_dict(result))
-- 
cgit v1.2.1