From e1589b318cf026f454cc6bc34a42c411dbb56047 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 15 Apr 2024 19:01:01 +0200 Subject: Make PolyMatrixDict accept indexing p[row, col] --- polymatrix/polymatrix/typing.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py index 43d7210..1de5c36 100644 --- a/polymatrix/polymatrix/typing.py +++ b/polymatrix/polymatrix/typing.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import NamedTuple, Iterable +from collections import UserDict # TODO: remove these types, they are here for backward compatiblity MonomialData = tuple[tuple[int, int], ...] @@ -34,17 +35,10 @@ class MonomialIndex(tuple[VariableIndex]): class PolyDict(dict[MonomialIndex, int | float]): """ Polynomial, stored as a dictionary. """ - # NOTE: this class directly inherits from dict instead of - # collections.UserDict, hence it is not possible to override any of the - # __dunder__ methods. If you like to change behaviour of one of these - # methods change to inherit from UserDict, this could incur a slight - # performance cost, because dict is straight CPython and thus faster - @staticmethod def empty() -> PolyDict: return PolyDict({}) - def terms(self) -> Iterable[tuple[MonomialIndex, int | float]]: """ Iterate over terms with a non-zero coefficient. """ # This is an alias for readability @@ -67,7 +61,15 @@ class MatrixIndex(NamedTuple): col: int -class PolyMatrixDict(dict[MatrixIndex, PolyDict]): +class PolyMatrixDict(UserDict[MatrixIndex, PolyDict]): """ Matrix whose entries are polynomials, stored as a dictionary. """ - # NOTE: Same as NOTE in PolyDict. + def __getitem__(self, key: tuple[int, int] | MatrixIndex) -> PolyDict: + if not isinstance(key, MatrixIndex): + key = MatrixIndex(*key) + return super().__getitem__(key) + + def __setitem__(self, key: tuple[int, int] | MatrixIndex, value: PolyDict): + if not isinstance(key, MatrixIndex): + key = MatrixIndex(*key) + return super().__setitem__(key, value) -- cgit v1.2.1