summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/polymatrix/typing.py20
1 files changed, 11 insertions, 9 deletions
diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py
index 43d7210..1de5c36 100644
--- a/polymatrix/polymatrix/typing.py
+++ b/polymatrix/polymatrix/typing.py
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import NamedTuple, Iterable
+from collections import UserDict
# TODO: remove these types, they are here for backward compatiblity
MonomialData = tuple[tuple[int, int], ...]
@@ -34,17 +35,10 @@ class MonomialIndex(tuple[VariableIndex]):
class PolyDict(dict[MonomialIndex, int | float]):
""" Polynomial, stored as a dictionary. """
- # NOTE: this class directly inherits from dict instead of
- # collections.UserDict, hence it is not possible to override any of the
- # __dunder__ methods. If you like to change behaviour of one of these
- # methods change to inherit from UserDict, this could incur a slight
- # performance cost, because dict is straight CPython and thus faster
-
@staticmethod
def empty() -> PolyDict:
return PolyDict({})
-
def terms(self) -> Iterable[tuple[MonomialIndex, int | float]]:
""" Iterate over terms with a non-zero coefficient. """
# This is an alias for readability
@@ -67,7 +61,15 @@ class MatrixIndex(NamedTuple):
col: int
-class PolyMatrixDict(dict[MatrixIndex, PolyDict]):
+class PolyMatrixDict(UserDict[MatrixIndex, PolyDict]):
""" Matrix whose entries are polynomials, stored as a dictionary. """
- # NOTE: Same as NOTE in PolyDict.
+ def __getitem__(self, key: tuple[int, int] | MatrixIndex) -> PolyDict:
+ if not isinstance(key, MatrixIndex):
+ key = MatrixIndex(*key)
+ return super().__getitem__(key)
+
+ def __setitem__(self, key: tuple[int, int] | MatrixIndex, value: PolyDict):
+ if not isinstance(key, MatrixIndex):
+ key = MatrixIndex(*key)
+ return super().__setitem__(key, value)