aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-09 00:35:34 +0100
committerNao Pross <np@0hm.ch>2024-03-09 00:50:59 +0100
commit214a40c933eeb1fa0b98c570b2773e497b47a6e6 (patch)
tree9fbbd431446a9dd41c0a93acb61cf59f54133742
parentPass shape parameter to Repr constructor (diff)
downloadmdpoly-214a40c933eeb1fa0b98c570b2773e497b47a6e6.tar.gz
mdpoly-214a40c933eeb1fa0b98c570b2773e497b47a6e6.zip
Improve type signature of Expr.to_repr
Diffstat (limited to '')
-rw-r--r--mdpoly/abc.py12
-rw-r--r--mdpoly/algebra.py42
2 files changed, 32 insertions, 22 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index 34cf090..245f3a2 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -10,7 +10,7 @@ from .util import iszero
if TYPE_CHECKING:
from .state import State
-from typing import Self, TypeVar, Generic, Any, Iterable, Sequence
+from typing import Self, Type, TypeVar, Generic, Any, Iterable, Sequence
from enum import Enum, auto
from copy import copy
from abc import ABC, abstractmethod
@@ -18,6 +18,12 @@ from dataclassabc import dataclassabc
from hashlib import sha256
+# Forward declaration of Repr class for TypeVar
+# Why is this shit necessary? Feels like writing C++
+class Repr: ...
+
+ReprT = TypeVar("ReprT", bound=Repr)
+
# ┏━╸╻ ╻┏━┓┏━┓┏━╸┏━┓┏━┓╻┏━┓┏┓╻┏━┓
# ┣╸ ┏╋┛┣━┛┣┳┛┣╸ ┗━┓┗━┓┃┃ ┃┃┗┫┗━┓
# ┗━╸╹ ╹╹ ╹┗╸┗━╸┗━┛┗━┛╹┗━┛╹ ╹┗━┛
@@ -102,7 +108,7 @@ class Expr(ABC):
# --- Methods for polynomial expression ---
@abstractmethod
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
""" Convert to a concrete representation
To convert an abstract object from the expression tree, we ... TODO: finish.
@@ -350,7 +356,7 @@ class Nothing(Leaf):
shape: Shape = Shape(0, 0)
algebra: Algebra = Algebra.none
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
raise ValueError("Nothing cannot be represented.")
def __repr__(self) -> str:
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index f4656d1..38373c0 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -9,13 +9,15 @@ from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number
from typing import cast, Sequence, Iterable, Type, TypeVar
from functools import reduce
from itertools import product, chain, combinations_with_replacement
-from math import prod
from abc import abstractmethod
from dataclassabc import dataclassabc
import operator
+ReprT = TypeVar("ReprT", bound=Repr)
+
+
class BinaryOp(Expr):
""" Binary Operator """
def __init__(self, left, right):
@@ -76,7 +78,7 @@ class Reducible(Expr):
addition and multiplication.
"""
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
""" See :py:meth:`mdpoly.abc.Expr.to_repr`. """
return self.reduce().to_repr(repr_type, state)
@@ -128,17 +130,19 @@ class PolyRingExpr(Expr):
returns a polynomial :math:`x^5 + x^3 + 1`.
"""
nonzero_exponents = filter(lambda e: e != 0, exponents)
- constant_term = PolyConst(1 if 0 in exponents else 0)
+ # Ignore typecheck because dataclassabc has no typing stub
+ constant_term = PolyConst(1 if 0 in exponents else 0) # type: ignore[call-arg]
# FIXME: remove annoying 0
return sum((variable ** e for e in nonzero_exponents), constant_term)
@classmethod
- def make_combinations(cls, variables: Iterable[Var], max_degree: int) -> Iterable[PolyRingExpr]:
+ def make_combinations(cls, variables: Iterable[PolyVar], max_degree: int) -> Iterable[PolyRingExpr]:
""" Make a list of """
- variables = chain(variables, (PolyConst(1),))
- for comb in combinations_with_replacement(variables, max_degree):
- yield prod(comb)
+ # Ignore typecheck because dataclassabc has no typing stub
+ vars_and_const = chain(variables, (PolyConst(1),)) # type: ignore[call-arg]
+ for comb in combinations_with_replacement(vars_and_const, max_degree):
+ yield cast(PolyRingExpr, reduce(operator.mul, comb))
# --- Operator Overloading ---
@@ -205,7 +209,7 @@ class PolyVar(Var, PolyRingExpr):
name: str # overloads Leaf.name
shape: Shape = Shape.scalar() # ovearloads PolyRingExpr.shape
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
r = repr_type(self.shape)
idx = PolyVarIndex.from_var(self, state), # important comma!
r.set(MatrixIndex.scalar(), PolyIndex(idx), 1)
@@ -219,7 +223,7 @@ class PolyConst(Const, PolyRingExpr):
name: str = "" # overloads Leaf.name
shape: Shape = Shape.scalar() # ovearloads PolyRingExpr.shape
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
r = repr_type(self.shape)
r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value)
return r, state
@@ -231,7 +235,7 @@ class PolyParam(Param, PolyRingExpr):
name: str # overloads Leaf.name
shape: Shape = Shape.scalar() # overloads PolyRingExpr.shape
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
if self not in state.parameters:
raise MissingParameters("Cannot construct representation because "
f"value for parameter {self} was not given.")
@@ -242,7 +246,7 @@ class PolyParam(Param, PolyRingExpr):
class PolyAdd(BinaryOp, PolyRingExpr):
""" Addition operator between scalar polynomials. """
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
""" See :py:meth:`mdpoly.abc.Expr.to_repr`. """
# Make a new empty representation
r = repr_type(self.shape)
@@ -277,7 +281,7 @@ class PolySub(BinaryOp, PolyRingExpr, Reducible):
class PolyMul(BinaryOp, PolyRingExpr):
""" Multiplication operator between scalar polynomials. """
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
""" See :py:meth:`mdpoly.abc.Expr.to_repr`. """
r = repr_type(self.shape)
@@ -345,7 +349,7 @@ class PolyPartialDiff(UnaryOp, PolyRingExpr):
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
return self.inner.shape
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
""" See :py:meth:`mdpoly.abc.Expr.to_repr`. """
r = repr_type(self.shape)
lrepr, state = self.inner.to_repr(repr_type, state)
@@ -464,7 +468,7 @@ class MatConst(Const, MatrixExpr):
shape: Shape # overloads Expr.shape
name: str = "" # overloads Leaf.name
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
r = repr_type(self.shape)
for r, row in enumerate(self.value):
@@ -500,7 +504,7 @@ class MatVar(Var, MatrixExpr):
yield entry, var
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
r = repr_type(self.shape)
# FIXME: do not hardcode scalar type
@@ -522,7 +526,7 @@ class MatParam(Param, MatrixExpr):
name: str
shape: Shape
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
if self not in state.parameters:
raise MissingParameters("Cannot construct representation because "
f"value for parameter {self} was not given.")
@@ -546,7 +550,7 @@ class MatAdd(BinaryOp, MatrixExpr):
return self.left.shape
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
""" See :py:meth:`mdpoly.abc.Expr.to_repr`. """
# Make a new empty representation
r = repr_type(self.shape)
@@ -598,7 +602,7 @@ class MatElemMul(BinaryOp, MatrixExpr):
f"{self.left.shape} and {self.right.shape}")
return self.left.shape
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
""" See :py:meth:`mdpoly.abc.Expr.to_repr`. """
r = repr_type(self.shape)
@@ -638,7 +642,7 @@ class MatScalarMul(BinaryOp, MatrixExpr):
return self.right.shape
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
""" See :py:meth:`mdpoly.abc.Expr.to_repr`. """
r = repr_type(self.shape)