aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/types.py65
1 files changed, 61 insertions, 4 deletions
diff --git a/mdpoly/types.py b/mdpoly/types.py
index b580b99..4b75e8f 100644
--- a/mdpoly/types.py
+++ b/mdpoly/types.py
@@ -1,7 +1,9 @@
from __future__ import annotations
+from .constants import NUMERICS_EPS
from .errors import InvalidShape
+from itertools import filterfalse
from typing import Self, NamedTuple, Iterable, TYPE_CHECKING
if TYPE_CHECKING:
@@ -68,15 +70,28 @@ class PolyVarIndex(NamedTuple):
"""
Tuple to index a power of a variable.
- Concretely this represents things like x^2, y^4
+ Concretely this represents things like x^2, y^4 but not xy
"""
var_idx: int # Index in State.variables, not variable!
power: Number
+ def __eq__(self, other: Self):
+ if self.var_idx != other.var_idx:
+ return False
+
+ if (self.power - other.power) > NUMERICS_EPS:
+ return False
+
+ return True
+
+ def __lt__(self, other: Self):
+ return self.var_idx < other.var_idx
+
@classmethod
@property
def constant(cls):
- """ Special index for constants.
+ """
+ Special index for constants.
Constants do not have an associated variable,
and the field power is ignored.
@@ -85,6 +100,7 @@ class PolyVarIndex(NamedTuple):
@staticmethod
def is_constant(index: Self) -> bool:
+ """ Check if is index of a constant term """
return index.var_idx == -1
@classmethod
@@ -94,7 +110,7 @@ class PolyVarIndex(NamedTuple):
class PolyIndex(tuple[PolyVarIndex]):
"""
- Tuple to index a coefficient of a polynomial.
+ Tuple to index a coefficient of a monomial in a polynomial.
Example:
@@ -117,7 +133,48 @@ class PolyIndex(tuple[PolyVarIndex]):
return f"{name}({indices})"
@classmethod
+ def from_dict(cls, d) -> Self:
+ return cls(PolyVarIndex(k, v) for k, v in d.items())
+
+ @classmethod
@property
- def constant(cls):
+ def constant(cls) -> Self:
""" Index of the constant term """
return cls((PolyVarIndex.constant,))
+
+ @staticmethod
+ def is_constant(index: Self) -> bool:
+ """ Check if is index of a constant term """
+ if len(index) != 1:
+ # FIXME: error message
+ raise IndexError(f"{index}")
+
+ return PolyVarIndex.is_constant(index[0])
+
+ @classmethod
+ def sort(cls, index: Self) -> Self:
+ """ Sort a tuple of indices """
+ return cls(sorted(index))
+
+ @classmethod
+ def product(cls, left: Self, right: Self) -> Self:
+ """
+ Compute index of product
+
+ For example if left is the index of xy and right is the index of y^2
+ this functions returns the index of xy^3
+ """
+ if cls.is_constant(left):
+ return right
+
+ if cls.is_constant(right):
+ return left
+
+ result: dict[int, Number] = dict(filterfalse(PolyVarIndex.is_constant, left))
+ for idx, power in filterfalse(PolyVarIndex.is_constant, right):
+ if idx not in result:
+ result[idx] = power
+ else:
+ result[idx] += power
+
+ return cls.sort(cls.from_dict(result))