aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-04 02:16:25 +0100
committerNao Pross <np@0hm.ch>2024-03-04 02:26:43 +0100
commit257ccc86c49fb2e45ff0c46a54882245b5056bed (patch)
tree83ccb4901ad62b9b9a6a2cd2a90055bbb9c88778
parentAdd shape check for MatScalarMul (diff)
downloadmdpoly-257ccc86c49fb2e45ff0c46a54882245b5056bed.tar.gz
mdpoly-257ccc86c49fb2e45ff0c46a54882245b5056bed.zip
Fix types (or, make mypy happier)
-rw-r--r--mdpoly/abc.py24
-rw-r--r--mdpoly/algebra.py76
-rw-r--r--mdpoly/leaves.py19
-rw-r--r--mdpoly/representations.py12
-rw-r--r--mdpoly/types.py32
5 files changed, 85 insertions, 78 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index 991cc04..c291fdc 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -4,9 +4,8 @@ from .types import Number, Shape, MatrixIndex, PolyIndex
from .constants import NUMERICS_EPS
from .util import iszero
-from typing import Self, Sequence, Protocol, runtime_checkable
+from typing import Self, Iterable, Sequence, Protocol, runtime_checkable
from abc import abstractmethod
-from functools import cached_property
@runtime_checkable
@@ -32,7 +31,7 @@ class Expr(Protocol):
left: Self | Leaf
right: Self | Leaf
- @cached_property
+ @property
@abstractmethod
def shape(self) -> Shape:
""" Computes the shape of the expression. """
@@ -41,11 +40,11 @@ class Expr(Protocol):
name = self.__class__.__qualname__
return f"{name}(left={self.left}, right={self.right})"
- def children(self) -> Sequence[Self]:
+ def children(self) -> Sequence[Self | Leaf]:
""" Iterate over the two nodes """
return self.left, self.right
- def leaves(self) -> Sequence[Leaf]:
+ def leaves(self) -> Iterable[Leaf]:
""" Returns the leaves of the tree. This is done recursively and in
:math:`\mathcal{O}(n)`."""
if isinstance(self.left, Leaf):
@@ -102,10 +101,17 @@ class Expr(Protocol):
return replace_all(self)
- def __iter__(self) -> Sequence[Self | Leaf]:
+ def __iter__(self) -> Iterable[Self | Leaf]:
yield from self.children()
+@runtime_checkable
+class Rel(Protocol):
+ """ Relation between two expressions. """
+ lhs: Expr
+ rhs: Expr
+
+
class Repr(Protocol):
r""" Representation of a multivariate matrix polynomial expression.
@@ -140,11 +146,11 @@ class Repr(Protocol):
self.set(entry, term, 0.)
@abstractmethod
- def entries(self) -> Sequence[MatrixIndex]:
+ def entries(self) -> Iterable[MatrixIndex]:
""" Return indices to non-zero entries of the matrix. """
@abstractmethod
- def terms(self, entry: MatrixIndex) -> Sequence[PolyIndex]:
+ def terms(self, entry: MatrixIndex) -> Iterable[PolyIndex]:
""" Return indices to non-zero terms in the polynomial at the given
matrix entry. """
@@ -156,7 +162,7 @@ class Repr(Protocol):
if self.is_zero(entry, term, eps=eps):
self.set_zero(entry, term)
- def __iter__(self) -> Sequence[tuple[MatrixIndex, PolyIndex, Number]]:
+ def __iter__(self) -> Iterable[tuple[MatrixIndex, PolyIndex, Number]]:
""" Iterate over non-zero entries of the representations. """
for entry in self.entries():
for term in self.terms(entry):
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index ef3f7eb..a9fc806 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -7,7 +7,7 @@ from .state import State
from .types import Shape, MatrixIndex, PolyIndex
from .representations import HasRepr
-from typing import Protocol, TypeVar, Any, runtime_checkable
+from typing import Protocol, Self, Any, runtime_checkable
from functools import wraps, reduce, cached_property
from itertools import product
from enum import Enum
@@ -16,8 +16,6 @@ from abc import abstractmethod
import operator
-T = TypeVar("T")
-
@runtime_checkable
class AlgebraicStructure(Protocol):
""" Provides methods to enforce algebraic closure of operations.
@@ -37,19 +35,19 @@ class AlgebraicStructure(Protocol):
_constant: type
@classmethod
- def _is_constant(cls: T, other: T) -> bool:
+ def _is_constant(cls, other: Self) -> bool:
return isinstance(other, cls._constant)
@classmethod
- def _is_parameter(cls: T, other: T) -> bool:
+ def _is_parameter(cls, other: Self) -> bool:
return isinstance(other, cls._parameter)
@classmethod
- def _is_const_or_param(cls: T, other: T) -> bool:
+ def _is_const_or_param(cls, other: Self) -> bool:
return cls._is_constant(other) or cls._is_parameter(other)
@classmethod
- def _assert_same_algebra(cls: T, other: T) -> None:
+ def _assert_same_algebra(cls, other: Self) -> None:
if not cls._algebra == other._algebra:
if not cls._is_constant(other):
raise AlgebraicError("Cannot perform operation between types from "
@@ -57,7 +55,7 @@ class AlgebraicStructure(Protocol):
f"and {type(other)} ({other._algebra})")
@classmethod
- def _wrap_if_constant(cls: T, other: Any):
+ def _wrap_if_constant(cls, other: Any):
if isinstance(other, AlgebraicStructure):
cls._assert_same_algebra(other)
return other
@@ -75,6 +73,10 @@ class AlgebraicStructure(Protocol):
# f"objects with different shapes {cls.shape} and {other.shape}")
+class ExprWithRepr(Expr, HasRepr, Protocol):
+ """ Expression that has a representation """
+
+
class ReducibleExpr(HasRepr, Protocol):
""" Reducible Expression
@@ -84,12 +86,12 @@ class ReducibleExpr(HasRepr, Protocol):
addition and multiplication.
"""
- def to_repr(self, repr_type: type, state: State) -> Repr:
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
""" See :py:meth:`mdpoly.representations.HasRepr.to_repr` """
return self.reduce().to_repr(repr_type, state)
@abstractmethod
- def reduce(self) -> Expr:
+ def reduce(self) -> ExprWithRepr:
""" Reduce the expression to its basic elements """
@@ -145,12 +147,11 @@ def unary_operator(inner_type: AlgebraicStructure):
init_cls(self, *args, **kwargs)
self.left, self.right = left, Nothing
- @property
def inner(self):
return self.left
cls.__init__ = new_init_cls
- cls.inner = inner
+ cls.inner = property(inner)
return cls
return decorator
@@ -165,11 +166,10 @@ class Add(Expr, HasRepr):
raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
return self.left.shape
- def to_repr(self, repr_type: type, state: State) -> Repr:
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
""" See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
# Make a new empty representation
r = repr_type()
- entry = MatrixIndex.scalar
# Make representations for existing stuff
for node in self.children():
@@ -197,10 +197,10 @@ class Sub(Expr, ReducibleExpr):
raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
return self.left.shape
- def reduce(self) -> Expr:
+ def reduce(self) -> ExprWithRepr:
""" See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
- return self.left + (-1 * self.right)
- # return Add(self.left, Mul(self._constant(value=-1), self.right))
+ # return self.left + (-1 * self.right)
+ return Add(self.left, Mul(self._constant(value=-1), self.right))
def __repr__(self) -> str:
return f"({self.left} - {self.right})"
@@ -216,21 +216,25 @@ class Mul(Expr, HasRepr):
raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
return self.left.shape
- def to_repr(self, repr_type: type, state: State) -> Repr:
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
""" See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
r = repr_type()
lrepr, state = self.left.to_repr(repr_type, state)
rrepr, state = self.right.to_repr(repr_type, state)
- entry = MatrixIndex.scalar
- for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)):
- # Compute where the results should go
- term = PolyIndex.product(lterm, rterm)
+ # Non zero entries are the intersection since if either is zero the
+ # result is zero
+ nonzero_entries = set(lrepr.entries()) & set(rrepr.entries())
+ for entry in nonzero_entries:
+ # Compute polynomial product between non-zero entries
+ for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)):
+ # Compute where the results should go
+ term = PolyIndex.product(lterm, rterm)
- # Compute product
- p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm)
- r.set(entry, term, p)
+ # Compute product
+ p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm)
+ r.set(entry, term, p)
return r, state
@@ -248,7 +252,7 @@ class Exp(Expr, ReducibleExpr):
raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.")
return self.left.shape
- def reduce(self) -> Expr:
+ def reduce(self) -> ExprWithRepr:
""" See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
var = self.left
ntimes = self.right.value - 1
@@ -268,9 +272,9 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol):
This is the algebra of :math:`\mathbb{R}[x_1, \ldots, x_n]`. Note that the
polynomials are scalars.
"""
- _algebra = AlgebraicStructure.Algebra.poly_ring
- _parameter = Param
- _constant = Const
+ _algebra: AlgebraicStructure.Algebra = AlgebraicStructure.Algebra.poly_ring
+ _parameter: type = Param
+ _constant: type = Const
def __add__(self, other):
other = self._wrap_if_constant(other)
@@ -333,7 +337,7 @@ class PolyMul(Mul, PolyRingAlgebra):
class PolyDiv(Expr, ReducibleExpr, PolyRingAlgebra):
""" Division of scalar polynomial by scalar. """
- def reduce(self) -> Expr:
+ def reduce(self) -> ExprWithRepr:
""" See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """
inverse = self._constant(value=1/self.right.value)
return PolyMul(inverse, self.left)
@@ -358,12 +362,12 @@ class PartialDiff(Expr, HasRepr, PolyRingAlgebra):
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
return self.inner.shape
- def to_repr(self, repr_type: type, state: State) -> Repr:
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
""" See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """
r = repr_type()
lrepr, state = self.left.to_repr(repr_type, state)
- entry = MatrixIndex.scalar
+ entry = MatrixIndex.scalar()
wrt = state.index(self.wrt)
for term in lrepr.terms(entry):
@@ -408,10 +412,10 @@ class MatrixAlgebra(AlgebraicStructure, Protocol):
already included (eg. transposition).
"""
- _algebra = AlgebraicStructure.Algebra.matrix_ring
+ _algebra: AlgebraicStructure.Algebra = AlgebraicStructure.Algebra.matrix_ring
# FIXME: consider MatParam or something like that?
- _parameter = Param
- _constant = Const
+ _parameter: type = Param
+ _constant: type = Const
def __add__(self, other):
other = self._wrap_if_constant(other)
@@ -504,7 +508,7 @@ class MatMul(Expr, MatrixAlgebra):
@property
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
- ...
+ raise NotImplementedError
def __repr__(self) -> str:
return f"({self.left} @ {self.right})"
diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py
index 14562a9..dbfd332 100644
--- a/mdpoly/leaves.py
+++ b/mdpoly/leaves.py
@@ -4,11 +4,8 @@ from .state import State
from .errors import MissingParameters
from .representations import HasRepr
-from typing import TypeVar
from dataclasses import dataclass
-T = TypeVar("T", bound=Repr)
-
@dataclass(frozen=True)
class Nothing(Leaf):
@@ -26,11 +23,11 @@ class Const(Leaf, HasRepr):
"""
value: Number
name: str = ""
- shape: Shape = Shape.scalar
+ shape: Shape = Shape.scalar()
- def to_repr(self, repr_type: type[T], state: State) -> T:
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
r = repr_type()
- r.set(MatrixIndex.scalar, PolyIndex.constant, self.value)
+ r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value)
return r, state
def __repr__(self) -> str:
@@ -46,12 +43,12 @@ class Var(Leaf, HasRepr):
Polynomial variable
"""
name: str
- shape: Shape = Shape.scalar
+ shape: Shape = Shape.scalar()
- def to_repr(self, repr_type: type[T], state: State) -> T:
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
r = repr_type()
idx = PolyVarIndex.from_var(self, state),
- r.set(MatrixIndex.scalar, PolyIndex(idx), 1)
+ r.set(MatrixIndex.scalar(), PolyIndex(idx), 1)
return r, state
def __repr__(self) -> str:
@@ -64,9 +61,9 @@ class Param(Leaf, HasRepr):
A parameter
"""
name: str
- shape: Shape = Shape.scalar
+ shape: Shape = Shape.scalar()
- def to_repr(self, repr_type: type[T], state: State) -> T:
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
if self not in state.parameters:
raise MissingParameters("Cannot construct representation because "
f"value for parameter {self} was not given.")
diff --git a/mdpoly/representations.py b/mdpoly/representations.py
index da038bd..6aa3088 100644
--- a/mdpoly/representations.py
+++ b/mdpoly/representations.py
@@ -2,7 +2,7 @@ from .abc import Repr
from .types import Number, Shape, MatrixIndex, PolyIndex
from .state import State
-from typing import Protocol, Sequence
+from typing import Protocol, Sequence, Iterable
from abc import abstractmethod
import numpy as np
@@ -50,11 +50,11 @@ class SparseRepr(Repr):
if not self.data[entry]:
del self.data[entry]
- def entries(self) -> Sequence[MatrixIndex]:
+ def entries(self) -> Iterable[MatrixIndex]:
""" Return indices to non-zero entries of the matrix. """
yield from self.data.keys()
- def terms(self, entry: MatrixIndex) -> Sequence[PolyIndex]:
+ def terms(self, entry: MatrixIndex) -> Iterable[PolyIndex]:
""" Return indices to terms with a non-zero coefficient in the
polynomial at the given matrix entry. """
if entry in self.data.keys():
@@ -97,7 +97,7 @@ class SparseMatrixRepr(Repr):
class DenseRepr(Repr):
""" Dense representation of a polynomial that uses Numpy arrays. """
- data: npt.NDArray[Number]
+ data: npt.NDArray
def __init__(self, shape: Shape, dtype:type =float):
self.data = np.zeros(shape, dtype)
@@ -110,11 +110,11 @@ class DenseRepr(Repr):
""" Set value of polynomial entry """
raise NotImplementedError
- def entries(self) -> Sequence[MatrixIndex]:
+ def entries(self) -> Iterable[MatrixIndex]:
""" Return indices to non-zero entries of the matrix """
raise NotImplementedError
- def terms(self, entry: MatrixIndex) -> Sequence[PolyIndex]:
+ def terms(self, entry: MatrixIndex) -> Iterable[PolyIndex]:
""" Return indices to non-zero terms in the polynomial at the given
matrix entry """
raise NotImplementedError
diff --git a/mdpoly/types.py b/mdpoly/types.py
index cb0213e..46c5882 100644
--- a/mdpoly/types.py
+++ b/mdpoly/types.py
@@ -23,13 +23,11 @@ class Shape(NamedTuple):
rows: int
cols: int
- @classmethod
@property
- def infer(cls) -> Self:
- return cls(-1, -1)
+ def infer(self) -> Self:
+ return self.__class__(-1, -1)
@classmethod
- @property
def scalar(cls) -> Self:
return cls(1, 1)
@@ -52,12 +50,10 @@ class MatrixIndex(NamedTuple):
col: int
@classmethod
- @property
def infer(cls, row=-1, col=-1) -> Self:
return cls(row, col)
@classmethod
- @property
def scalar(cls):
""" Shorthand for index of a scalar """
return cls(row=0, col=0)
@@ -79,7 +75,10 @@ class PolyVarIndex(NamedTuple):
""" Make an index from a variable object. """
return cls(var_idx=state.index(variable), power=power)
- def __eq__(self, other: Self):
+ def __eq__(self, other):
+ if type(other) is not PolyVarIndex:
+ other = PolyVarIndex(other)
+
if self.var_idx != other.var_idx:
return False
@@ -88,12 +87,14 @@ class PolyVarIndex(NamedTuple):
return True
- def __lt__(self, other: Self):
+ def __lt__(self, other):
+ if type(other) is not PolyVarIndex:
+ other = PolyVarIndex(other)
+
return self.var_idx < other.var_idx
@classmethod
- @property
- def constant(cls):
+ def constant(cls) -> Self:
""" Special index for constants.
Constants do not have an associated variable, and the field power is
@@ -102,7 +103,7 @@ class PolyVarIndex(NamedTuple):
return cls(var_idx=-1, power=0)
@staticmethod
- def is_constant(index: Self) -> bool:
+ def is_constant(index: PolyVarIndex) -> bool:
""" Check if is index of a constant term. """
return index.var_idx == -1
@@ -141,13 +142,12 @@ class PolyIndex(tuple[PolyVarIndex]):
return cls(PolyVarIndex(k, v) for k, v in d.items())
@classmethod
- @property
def constant(cls) -> Self:
""" Index of the constant term. """
- return cls((PolyVarIndex.constant,))
+ return cls((PolyVarIndex.constant(),))
@staticmethod
- def is_constant(index: Self) -> bool:
+ def is_constant(index: PolyIndex) -> bool:
""" Check if is index of a constant term. """
if len(index) != 1:
return False
@@ -155,7 +155,7 @@ class PolyIndex(tuple[PolyVarIndex]):
return PolyVarIndex.is_constant(index[0])
@classmethod
- def sort(cls, index: Self) -> Self:
+ def sort(cls, index: tuple | Self) -> Self:
""" Sort a tuple of indices. """
return cls(sorted(index))
@@ -203,7 +203,7 @@ class PolyIndex(tuple[PolyVarIndex]):
# Check if is linear term
if isclose(with_wrt_var.power, 1.):
- return cls.sort(tuple(others) + cls.constant)
+ return cls.sort(tuple(others) + cls.constant())
# Decrease exponent
new_idx = PolyVarIndex(var_idx=wrt, power=(with_wrt_var.power - 1))