diff options
-rw-r--r-- | mdpoly/abc.py | 7 | ||||
-rw-r--r-- | mdpoly/types.py | 7 | ||||
-rw-r--r-- | mdpoly/util.py | 9 |
3 files changed, 17 insertions, 6 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py index a38a69d..b127e8b 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -2,6 +2,7 @@ from .types import Number, Shape, MatrixIndex, PolyIndex from .constants import NUMERICS_EPS +from .util import iszero from typing import Self, Sequence, Protocol, runtime_checkable from abc import abstractmethod @@ -71,11 +72,13 @@ class Repr(Protocol): def is_zero(self, entry: MatrixIndex, term: PolyIndex, eps: Number = NUMERICS_EPS) -> bool: """ Check if a polynomial coefficient is zero. """ - return self.at(entry, term) < eps + return iszero(self.at(entry, term), eps=eps) @abstractmethod def set_zero(self, entry: MatrixIndex, term: PolyIndex) -> None: """ Set a coefficient to zero (delete it). """ + # Default implementation for refence, you should do something better! + self.set(entry, term, 0.) @abstractmethod def entries(self) -> Sequence[MatrixIndex]: @@ -91,7 +94,7 @@ class Repr(Protocol): See also :py:data:`mdpoly.constants.NUMERICS_EPS`. """ for entry in self.entries(): for term in self.terms(entry): - if self.at(entry, term) < eps: + if self.is_zero(entry, term, eps=eps): self.set_zero(entry, term) def __iter__(self) -> Sequence[tuple[MatrixIndex, PolyIndex, Number]]: diff --git a/mdpoly/types.py b/mdpoly/types.py index 1605aa8..7a50551 100644 --- a/mdpoly/types.py +++ b/mdpoly/types.py @@ -1,8 +1,7 @@ from __future__ import annotations -from .constants import NUMERICS_EPS from .errors import InvalidShape -from .util import partition +from .util import partition, isclose from itertools import filterfalse from typing import Self, NamedTuple, Iterable, Optional, TYPE_CHECKING @@ -84,7 +83,7 @@ class PolyVarIndex(NamedTuple): if self.var_idx != other.var_idx: return False - if (self.power - other.power) > NUMERICS_EPS: + if not isclose(self.power, other.power): return False return True @@ -203,7 +202,7 @@ class PolyIndex(tuple[PolyVarIndex]): with_wrt_var, *_ = with_wrt_var # Take first, as should be only one # Check if is linear term - if (with_wrt_var.power - 1) < NUMERICS_EPS: + if isclose(with_wrt_var.power, 1.): return cls.sort(tuple(others) + cls.constant) # Decrease exponent diff --git a/mdpoly/util.py b/mdpoly/util.py index f1d5de2..ed7c234 100644 --- a/mdpoly/util.py +++ b/mdpoly/util.py @@ -1,4 +1,7 @@ +from .constants import NUMERICS_EPS + from itertools import tee, filterfalse +from math import abs def partition(pred, iterable): """Partition entries into false entries and true entries. @@ -10,3 +13,9 @@ def partition(pred, iterable): return filterfalse(pred, t1), filter(pred, t2) +def isclose(x: float, y: float, eps: float =NUMERICS_EPS) -> bool: + return abs(x - y) < eps + + +def iszero(x: float, eps: float =NUMERICS_EPS) -> bool: + return isclose(x, 0.) |