diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/polymatrix/typing.py | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py index 1de5c36..4baff5a 100644 --- a/polymatrix/polymatrix/typing.py +++ b/polymatrix/polymatrix/typing.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import NamedTuple, Iterable +from typing import NamedTuple, Iterable, cast from collections import UserDict # TODO: remove these types, they are here for backward compatiblity @@ -15,11 +15,14 @@ class VariableIndex(NamedTuple): variable: int # index in ExpressionState object power: int - def __lt__(self, other): """ Variables indices can be sorted with respect to variable index. """ return self.variable < other.variable + @staticmethod + def constant() -> VariableIndex: + return cast(VariableIndex, tuple()) + class MonomialIndex(tuple[VariableIndex]): """ @@ -30,15 +33,28 @@ class MonomialIndex(tuple[VariableIndex]): def empty() -> MonomialIndex: return MonomialIndex(tuple()) + @staticmethod + def constant() -> MonomialIndex: + return MonomialIndex((VariableIndex.constant(),)) -class PolyDict(dict[MonomialIndex, int | float]): +class PolyDict(UserDict[MonomialIndex, int | float]): """ Polynomial, stored as a dictionary. """ @staticmethod def empty() -> PolyDict: return PolyDict({}) + def __getitem__(self, key: Iterable[VariableIndex] | MonomialIndex) -> int | float: + if not isinstance(key, MonomialIndex): + key = MonomialIndex(key) + return super().__getitem__(key) + + def __setitem__(self, key: Iterable[VariableIndex] | MonomialIndex, value: int | float): + if not isinstance(key, MonomialIndex): + key = MonomialIndex(key) + return super().__setitem__(key, value) + def terms(self) -> Iterable[tuple[MonomialIndex, int | float]]: """ Iterate over terms with a non-zero coefficient. """ # This is an alias for readability @@ -64,12 +80,20 @@ class MatrixIndex(NamedTuple): class PolyMatrixDict(UserDict[MatrixIndex, PolyDict]): """ Matrix whose entries are polynomials, stored as a dictionary. """ + @staticmethod + def empty() -> PolyMatrixDict: + return PolyMatrixDict({}) + def __getitem__(self, key: tuple[int, int] | MatrixIndex) -> PolyDict: if not isinstance(key, MatrixIndex): key = MatrixIndex(*key) return super().__getitem__(key) - def __setitem__(self, key: tuple[int, int] | MatrixIndex, value: PolyDict): + def __setitem__(self, key: tuple[int, int] | MatrixIndex, value: dict | PolyDict): if not isinstance(key, MatrixIndex): key = MatrixIndex(*key) + + if not isinstance(value, PolyDict): + value = PolyDict(value) + return super().__setitem__(key, value) |