diff options
author | Nao Pross <np@0hm.ch> | 2024-03-09 00:35:34 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-03-09 00:50:59 +0100 |
commit | 214a40c933eeb1fa0b98c570b2773e497b47a6e6 (patch) | |
tree | 9fbbd431446a9dd41c0a93acb61cf59f54133742 | |
parent | Pass shape parameter to Repr constructor (diff) | |
download | mdpoly-214a40c933eeb1fa0b98c570b2773e497b47a6e6.tar.gz mdpoly-214a40c933eeb1fa0b98c570b2773e497b47a6e6.zip |
Improve type signature of Expr.to_repr
Diffstat (limited to '')
-rw-r--r-- | mdpoly/abc.py | 12 | ||||
-rw-r--r-- | mdpoly/algebra.py | 42 |
2 files changed, 32 insertions, 22 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py index 34cf090..245f3a2 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -10,7 +10,7 @@ from .util import iszero if TYPE_CHECKING: from .state import State -from typing import Self, TypeVar, Generic, Any, Iterable, Sequence +from typing import Self, Type, TypeVar, Generic, Any, Iterable, Sequence from enum import Enum, auto from copy import copy from abc import ABC, abstractmethod @@ -18,6 +18,12 @@ from dataclassabc import dataclassabc from hashlib import sha256 +# Forward declaration of Repr class for TypeVar +# Why is this shit necessary? Feels like writing C++ +class Repr: ... + +ReprT = TypeVar("ReprT", bound=Repr) + # ┏━╸╻ ╻┏━┓┏━┓┏━╸┏━┓┏━┓╻┏━┓┏┓╻┏━┓ # ┣╸ ┏╋┛┣━┛┣┳┛┣╸ ┗━┓┗━┓┃┃ ┃┃┗┫┗━┓ # ┗━╸╹ ╹╹ ╹┗╸┗━╸┗━┛┗━┛╹┗━┛╹ ╹┗━┛ @@ -102,7 +108,7 @@ class Expr(ABC): # --- Methods for polynomial expression --- @abstractmethod - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: """ Convert to a concrete representation To convert an abstract object from the expression tree, we ... TODO: finish. @@ -350,7 +356,7 @@ class Nothing(Leaf): shape: Shape = Shape(0, 0) algebra: Algebra = Algebra.none - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: raise ValueError("Nothing cannot be represented.") def __repr__(self) -> str: diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index f4656d1..38373c0 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -9,13 +9,15 @@ from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number from typing import cast, Sequence, Iterable, Type, TypeVar from functools import reduce from itertools import product, chain, combinations_with_replacement -from math import prod from abc import abstractmethod from dataclassabc import dataclassabc import operator +ReprT = TypeVar("ReprT", bound=Repr) + + class BinaryOp(Expr): """ Binary Operator """ def __init__(self, left, right): @@ -76,7 +78,7 @@ class Reducible(Expr): addition and multiplication. """ - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ return self.reduce().to_repr(repr_type, state) @@ -128,17 +130,19 @@ class PolyRingExpr(Expr): returns a polynomial :math:`x^5 + x^3 + 1`. """ nonzero_exponents = filter(lambda e: e != 0, exponents) - constant_term = PolyConst(1 if 0 in exponents else 0) + # Ignore typecheck because dataclassabc has no typing stub + constant_term = PolyConst(1 if 0 in exponents else 0) # type: ignore[call-arg] # FIXME: remove annoying 0 return sum((variable ** e for e in nonzero_exponents), constant_term) @classmethod - def make_combinations(cls, variables: Iterable[Var], max_degree: int) -> Iterable[PolyRingExpr]: + def make_combinations(cls, variables: Iterable[PolyVar], max_degree: int) -> Iterable[PolyRingExpr]: """ Make a list of """ - variables = chain(variables, (PolyConst(1),)) - for comb in combinations_with_replacement(variables, max_degree): - yield prod(comb) + # Ignore typecheck because dataclassabc has no typing stub + vars_and_const = chain(variables, (PolyConst(1),)) # type: ignore[call-arg] + for comb in combinations_with_replacement(vars_and_const, max_degree): + yield cast(PolyRingExpr, reduce(operator.mul, comb)) # --- Operator Overloading --- @@ -205,7 +209,7 @@ class PolyVar(Var, PolyRingExpr): name: str # overloads Leaf.name shape: Shape = Shape.scalar() # ovearloads PolyRingExpr.shape - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: r = repr_type(self.shape) idx = PolyVarIndex.from_var(self, state), # important comma! r.set(MatrixIndex.scalar(), PolyIndex(idx), 1) @@ -219,7 +223,7 @@ class PolyConst(Const, PolyRingExpr): name: str = "" # overloads Leaf.name shape: Shape = Shape.scalar() # ovearloads PolyRingExpr.shape - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: r = repr_type(self.shape) r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value) return r, state @@ -231,7 +235,7 @@ class PolyParam(Param, PolyRingExpr): name: str # overloads Leaf.name shape: Shape = Shape.scalar() # overloads PolyRingExpr.shape - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: if self not in state.parameters: raise MissingParameters("Cannot construct representation because " f"value for parameter {self} was not given.") @@ -242,7 +246,7 @@ class PolyParam(Param, PolyRingExpr): class PolyAdd(BinaryOp, PolyRingExpr): """ Addition operator between scalar polynomials. """ - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ # Make a new empty representation r = repr_type(self.shape) @@ -277,7 +281,7 @@ class PolySub(BinaryOp, PolyRingExpr, Reducible): class PolyMul(BinaryOp, PolyRingExpr): """ Multiplication operator between scalar polynomials. """ - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ r = repr_type(self.shape) @@ -345,7 +349,7 @@ class PolyPartialDiff(UnaryOp, PolyRingExpr): """ See :py:meth:`mdpoly.abc.Expr.shape`. """ return self.inner.shape - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ r = repr_type(self.shape) lrepr, state = self.inner.to_repr(repr_type, state) @@ -464,7 +468,7 @@ class MatConst(Const, MatrixExpr): shape: Shape # overloads Expr.shape name: str = "" # overloads Leaf.name - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: r = repr_type(self.shape) for r, row in enumerate(self.value): @@ -500,7 +504,7 @@ class MatVar(Var, MatrixExpr): yield entry, var - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: r = repr_type(self.shape) # FIXME: do not hardcode scalar type @@ -522,7 +526,7 @@ class MatParam(Param, MatrixExpr): name: str shape: Shape - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: if self not in state.parameters: raise MissingParameters("Cannot construct representation because " f"value for parameter {self} was not given.") @@ -546,7 +550,7 @@ class MatAdd(BinaryOp, MatrixExpr): return self.left.shape - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ # Make a new empty representation r = repr_type(self.shape) @@ -598,7 +602,7 @@ class MatElemMul(BinaryOp, MatrixExpr): f"{self.left.shape} and {self.right.shape}") return self.left.shape - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ r = repr_type(self.shape) @@ -638,7 +642,7 @@ class MatScalarMul(BinaryOp, MatrixExpr): return self.right.shape - def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """ r = repr_type(self.shape) |