summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/polymatrix/typing.py32
1 files changed, 28 insertions, 4 deletions
diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py
index 1de5c36..4baff5a 100644
--- a/polymatrix/polymatrix/typing.py
+++ b/polymatrix/polymatrix/typing.py
@@ -1,5 +1,5 @@
from __future__ import annotations
-from typing import NamedTuple, Iterable
+from typing import NamedTuple, Iterable, cast
from collections import UserDict
# TODO: remove these types, they are here for backward compatiblity
@@ -15,11 +15,14 @@ class VariableIndex(NamedTuple):
variable: int # index in ExpressionState object
power: int
-
def __lt__(self, other):
""" Variables indices can be sorted with respect to variable index. """
return self.variable < other.variable
+ @staticmethod
+ def constant() -> VariableIndex:
+ return cast(VariableIndex, tuple())
+
class MonomialIndex(tuple[VariableIndex]):
"""
@@ -30,15 +33,28 @@ class MonomialIndex(tuple[VariableIndex]):
def empty() -> MonomialIndex:
return MonomialIndex(tuple())
+ @staticmethod
+ def constant() -> MonomialIndex:
+ return MonomialIndex((VariableIndex.constant(),))
-class PolyDict(dict[MonomialIndex, int | float]):
+class PolyDict(UserDict[MonomialIndex, int | float]):
""" Polynomial, stored as a dictionary. """
@staticmethod
def empty() -> PolyDict:
return PolyDict({})
+ def __getitem__(self, key: Iterable[VariableIndex] | MonomialIndex) -> int | float:
+ if not isinstance(key, MonomialIndex):
+ key = MonomialIndex(key)
+ return super().__getitem__(key)
+
+ def __setitem__(self, key: Iterable[VariableIndex] | MonomialIndex, value: int | float):
+ if not isinstance(key, MonomialIndex):
+ key = MonomialIndex(key)
+ return super().__setitem__(key, value)
+
def terms(self) -> Iterable[tuple[MonomialIndex, int | float]]:
""" Iterate over terms with a non-zero coefficient. """
# This is an alias for readability
@@ -64,12 +80,20 @@ class MatrixIndex(NamedTuple):
class PolyMatrixDict(UserDict[MatrixIndex, PolyDict]):
""" Matrix whose entries are polynomials, stored as a dictionary. """
+ @staticmethod
+ def empty() -> PolyMatrixDict:
+ return PolyMatrixDict({})
+
def __getitem__(self, key: tuple[int, int] | MatrixIndex) -> PolyDict:
if not isinstance(key, MatrixIndex):
key = MatrixIndex(*key)
return super().__getitem__(key)
- def __setitem__(self, key: tuple[int, int] | MatrixIndex, value: PolyDict):
+ def __setitem__(self, key: tuple[int, int] | MatrixIndex, value: dict | PolyDict):
if not isinstance(key, MatrixIndex):
key = MatrixIndex(*key)
+
+ if not isinstance(value, PolyDict):
+ value = PolyDict(value)
+
return super().__setitem__(key, value)