diff options
Diffstat (limited to '')
-rw-r--r-- | mdpoly/types.py | 65 |
1 files changed, 61 insertions, 4 deletions
diff --git a/mdpoly/types.py b/mdpoly/types.py index b580b99..4b75e8f 100644 --- a/mdpoly/types.py +++ b/mdpoly/types.py @@ -1,7 +1,9 @@ from __future__ import annotations +from .constants import NUMERICS_EPS from .errors import InvalidShape +from itertools import filterfalse from typing import Self, NamedTuple, Iterable, TYPE_CHECKING if TYPE_CHECKING: @@ -68,15 +70,28 @@ class PolyVarIndex(NamedTuple): """ Tuple to index a power of a variable. - Concretely this represents things like x^2, y^4 + Concretely this represents things like x^2, y^4 but not xy """ var_idx: int # Index in State.variables, not variable! power: Number + def __eq__(self, other: Self): + if self.var_idx != other.var_idx: + return False + + if (self.power - other.power) > NUMERICS_EPS: + return False + + return True + + def __lt__(self, other: Self): + return self.var_idx < other.var_idx + @classmethod @property def constant(cls): - """ Special index for constants. + """ + Special index for constants. Constants do not have an associated variable, and the field power is ignored. @@ -85,6 +100,7 @@ class PolyVarIndex(NamedTuple): @staticmethod def is_constant(index: Self) -> bool: + """ Check if is index of a constant term """ return index.var_idx == -1 @classmethod @@ -94,7 +110,7 @@ class PolyVarIndex(NamedTuple): class PolyIndex(tuple[PolyVarIndex]): """ - Tuple to index a coefficient of a polynomial. + Tuple to index a coefficient of a monomial in a polynomial. Example: @@ -117,7 +133,48 @@ class PolyIndex(tuple[PolyVarIndex]): return f"{name}({indices})" @classmethod + def from_dict(cls, d) -> Self: + return cls(PolyVarIndex(k, v) for k, v in d.items()) + + @classmethod @property - def constant(cls): + def constant(cls) -> Self: """ Index of the constant term """ return cls((PolyVarIndex.constant,)) + + @staticmethod + def is_constant(index: Self) -> bool: + """ Check if is index of a constant term """ + if len(index) != 1: + # FIXME: error message + raise IndexError(f"{index}") + + return PolyVarIndex.is_constant(index[0]) + + @classmethod + def sort(cls, index: Self) -> Self: + """ Sort a tuple of indices """ + return cls(sorted(index)) + + @classmethod + def product(cls, left: Self, right: Self) -> Self: + """ + Compute index of product + + For example if left is the index of xy and right is the index of y^2 + this functions returns the index of xy^3 + """ + if cls.is_constant(left): + return right + + if cls.is_constant(right): + return left + + result: dict[int, Number] = dict(filterfalse(PolyVarIndex.is_constant, left)) + for idx, power in filterfalse(PolyVarIndex.is_constant, right): + if idx not in result: + result[idx] = power + else: + result[idx] += power + + return cls.sort(cls.from_dict(result)) |