aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/abc.py7
-rw-r--r--mdpoly/types.py7
-rw-r--r--mdpoly/util.py9
3 files changed, 17 insertions, 6 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index a38a69d..b127e8b 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -2,6 +2,7 @@
from .types import Number, Shape, MatrixIndex, PolyIndex
from .constants import NUMERICS_EPS
+from .util import iszero
from typing import Self, Sequence, Protocol, runtime_checkable
from abc import abstractmethod
@@ -71,11 +72,13 @@ class Repr(Protocol):
def is_zero(self, entry: MatrixIndex, term: PolyIndex, eps: Number = NUMERICS_EPS) -> bool:
""" Check if a polynomial coefficient is zero. """
- return self.at(entry, term) < eps
+ return iszero(self.at(entry, term), eps=eps)
@abstractmethod
def set_zero(self, entry: MatrixIndex, term: PolyIndex) -> None:
""" Set a coefficient to zero (delete it). """
+ # Default implementation for refence, you should do something better!
+ self.set(entry, term, 0.)
@abstractmethod
def entries(self) -> Sequence[MatrixIndex]:
@@ -91,7 +94,7 @@ class Repr(Protocol):
See also :py:data:`mdpoly.constants.NUMERICS_EPS`. """
for entry in self.entries():
for term in self.terms(entry):
- if self.at(entry, term) < eps:
+ if self.is_zero(entry, term, eps=eps):
self.set_zero(entry, term)
def __iter__(self) -> Sequence[tuple[MatrixIndex, PolyIndex, Number]]:
diff --git a/mdpoly/types.py b/mdpoly/types.py
index 1605aa8..7a50551 100644
--- a/mdpoly/types.py
+++ b/mdpoly/types.py
@@ -1,8 +1,7 @@
from __future__ import annotations
-from .constants import NUMERICS_EPS
from .errors import InvalidShape
-from .util import partition
+from .util import partition, isclose
from itertools import filterfalse
from typing import Self, NamedTuple, Iterable, Optional, TYPE_CHECKING
@@ -84,7 +83,7 @@ class PolyVarIndex(NamedTuple):
if self.var_idx != other.var_idx:
return False
- if (self.power - other.power) > NUMERICS_EPS:
+ if not isclose(self.power, other.power):
return False
return True
@@ -203,7 +202,7 @@ class PolyIndex(tuple[PolyVarIndex]):
with_wrt_var, *_ = with_wrt_var # Take first, as should be only one
# Check if is linear term
- if (with_wrt_var.power - 1) < NUMERICS_EPS:
+ if isclose(with_wrt_var.power, 1.):
return cls.sort(tuple(others) + cls.constant)
# Decrease exponent
diff --git a/mdpoly/util.py b/mdpoly/util.py
index f1d5de2..ed7c234 100644
--- a/mdpoly/util.py
+++ b/mdpoly/util.py
@@ -1,4 +1,7 @@
+from .constants import NUMERICS_EPS
+
from itertools import tee, filterfalse
+from math import abs
def partition(pred, iterable):
"""Partition entries into false entries and true entries.
@@ -10,3 +13,9 @@ def partition(pred, iterable):
return filterfalse(pred, t1), filter(pred, t2)
+def isclose(x: float, y: float, eps: float =NUMERICS_EPS) -> bool:
+ return abs(x - y) < eps
+
+
+def iszero(x: float, eps: float =NUMERICS_EPS) -> bool:
+ return isclose(x, 0.)