diff options
author | Nao Pross <np@0hm.ch> | 2024-03-04 02:16:25 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-03-04 02:26:43 +0100 |
commit | 257ccc86c49fb2e45ff0c46a54882245b5056bed (patch) | |
tree | 83ccb4901ad62b9b9a6a2cd2a90055bbb9c88778 | |
parent | Add shape check for MatScalarMul (diff) | |
download | mdpoly-257ccc86c49fb2e45ff0c46a54882245b5056bed.tar.gz mdpoly-257ccc86c49fb2e45ff0c46a54882245b5056bed.zip |
Fix types (or, make mypy happier)
Diffstat (limited to '')
-rw-r--r-- | mdpoly/abc.py | 24 | ||||
-rw-r--r-- | mdpoly/algebra.py | 76 | ||||
-rw-r--r-- | mdpoly/leaves.py | 19 | ||||
-rw-r--r-- | mdpoly/representations.py | 12 | ||||
-rw-r--r-- | mdpoly/types.py | 32 |
5 files changed, 85 insertions, 78 deletions
diff --git a/mdpoly/abc.py b/mdpoly/abc.py index 991cc04..c291fdc 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -4,9 +4,8 @@ from .types import Number, Shape, MatrixIndex, PolyIndex from .constants import NUMERICS_EPS from .util import iszero -from typing import Self, Sequence, Protocol, runtime_checkable +from typing import Self, Iterable, Sequence, Protocol, runtime_checkable from abc import abstractmethod -from functools import cached_property @runtime_checkable @@ -32,7 +31,7 @@ class Expr(Protocol): left: Self | Leaf right: Self | Leaf - @cached_property + @property @abstractmethod def shape(self) -> Shape: """ Computes the shape of the expression. """ @@ -41,11 +40,11 @@ class Expr(Protocol): name = self.__class__.__qualname__ return f"{name}(left={self.left}, right={self.right})" - def children(self) -> Sequence[Self]: + def children(self) -> Sequence[Self | Leaf]: """ Iterate over the two nodes """ return self.left, self.right - def leaves(self) -> Sequence[Leaf]: + def leaves(self) -> Iterable[Leaf]: """ Returns the leaves of the tree. This is done recursively and in :math:`\mathcal{O}(n)`.""" if isinstance(self.left, Leaf): @@ -102,10 +101,17 @@ class Expr(Protocol): return replace_all(self) - def __iter__(self) -> Sequence[Self | Leaf]: + def __iter__(self) -> Iterable[Self | Leaf]: yield from self.children() +@runtime_checkable +class Rel(Protocol): + """ Relation between two expressions. """ + lhs: Expr + rhs: Expr + + class Repr(Protocol): r""" Representation of a multivariate matrix polynomial expression. @@ -140,11 +146,11 @@ class Repr(Protocol): self.set(entry, term, 0.) @abstractmethod - def entries(self) -> Sequence[MatrixIndex]: + def entries(self) -> Iterable[MatrixIndex]: """ Return indices to non-zero entries of the matrix. """ @abstractmethod - def terms(self, entry: MatrixIndex) -> Sequence[PolyIndex]: + def terms(self, entry: MatrixIndex) -> Iterable[PolyIndex]: """ Return indices to non-zero terms in the polynomial at the given matrix entry. """ @@ -156,7 +162,7 @@ class Repr(Protocol): if self.is_zero(entry, term, eps=eps): self.set_zero(entry, term) - def __iter__(self) -> Sequence[tuple[MatrixIndex, PolyIndex, Number]]: + def __iter__(self) -> Iterable[tuple[MatrixIndex, PolyIndex, Number]]: """ Iterate over non-zero entries of the representations. """ for entry in self.entries(): for term in self.terms(entry): diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index ef3f7eb..a9fc806 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -7,7 +7,7 @@ from .state import State from .types import Shape, MatrixIndex, PolyIndex from .representations import HasRepr -from typing import Protocol, TypeVar, Any, runtime_checkable +from typing import Protocol, Self, Any, runtime_checkable from functools import wraps, reduce, cached_property from itertools import product from enum import Enum @@ -16,8 +16,6 @@ from abc import abstractmethod import operator -T = TypeVar("T") - @runtime_checkable class AlgebraicStructure(Protocol): """ Provides methods to enforce algebraic closure of operations. @@ -37,19 +35,19 @@ class AlgebraicStructure(Protocol): _constant: type @classmethod - def _is_constant(cls: T, other: T) -> bool: + def _is_constant(cls, other: Self) -> bool: return isinstance(other, cls._constant) @classmethod - def _is_parameter(cls: T, other: T) -> bool: + def _is_parameter(cls, other: Self) -> bool: return isinstance(other, cls._parameter) @classmethod - def _is_const_or_param(cls: T, other: T) -> bool: + def _is_const_or_param(cls, other: Self) -> bool: return cls._is_constant(other) or cls._is_parameter(other) @classmethod - def _assert_same_algebra(cls: T, other: T) -> None: + def _assert_same_algebra(cls, other: Self) -> None: if not cls._algebra == other._algebra: if not cls._is_constant(other): raise AlgebraicError("Cannot perform operation between types from " @@ -57,7 +55,7 @@ class AlgebraicStructure(Protocol): f"and {type(other)} ({other._algebra})") @classmethod - def _wrap_if_constant(cls: T, other: Any): + def _wrap_if_constant(cls, other: Any): if isinstance(other, AlgebraicStructure): cls._assert_same_algebra(other) return other @@ -75,6 +73,10 @@ class AlgebraicStructure(Protocol): # f"objects with different shapes {cls.shape} and {other.shape}") +class ExprWithRepr(Expr, HasRepr, Protocol): + """ Expression that has a representation """ + + class ReducibleExpr(HasRepr, Protocol): """ Reducible Expression @@ -84,12 +86,12 @@ class ReducibleExpr(HasRepr, Protocol): addition and multiplication. """ - def to_repr(self, repr_type: type, state: State) -> Repr: + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: """ See :py:meth:`mdpoly.representations.HasRepr.to_repr` """ return self.reduce().to_repr(repr_type, state) @abstractmethod - def reduce(self) -> Expr: + def reduce(self) -> ExprWithRepr: """ Reduce the expression to its basic elements """ @@ -145,12 +147,11 @@ def unary_operator(inner_type: AlgebraicStructure): init_cls(self, *args, **kwargs) self.left, self.right = left, Nothing - @property def inner(self): return self.left cls.__init__ = new_init_cls - cls.inner = inner + cls.inner = property(inner) return cls return decorator @@ -165,11 +166,10 @@ class Add(Expr, HasRepr): raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") return self.left.shape - def to_repr(self, repr_type: type, state: State) -> Repr: + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ # Make a new empty representation r = repr_type() - entry = MatrixIndex.scalar # Make representations for existing stuff for node in self.children(): @@ -197,10 +197,10 @@ class Sub(Expr, ReducibleExpr): raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") return self.left.shape - def reduce(self) -> Expr: + def reduce(self) -> ExprWithRepr: """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ - return self.left + (-1 * self.right) - # return Add(self.left, Mul(self._constant(value=-1), self.right)) + # return self.left + (-1 * self.right) + return Add(self.left, Mul(self._constant(value=-1), self.right)) def __repr__(self) -> str: return f"({self.left} - {self.right})" @@ -216,21 +216,25 @@ class Mul(Expr, HasRepr): raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") return self.left.shape - def to_repr(self, repr_type: type, state: State) -> Repr: + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ r = repr_type() lrepr, state = self.left.to_repr(repr_type, state) rrepr, state = self.right.to_repr(repr_type, state) - entry = MatrixIndex.scalar - for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)): - # Compute where the results should go - term = PolyIndex.product(lterm, rterm) + # Non zero entries are the intersection since if either is zero the + # result is zero + nonzero_entries = set(lrepr.entries()) & set(rrepr.entries()) + for entry in nonzero_entries: + # Compute polynomial product between non-zero entries + for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)): + # Compute where the results should go + term = PolyIndex.product(lterm, rterm) - # Compute product - p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm) - r.set(entry, term, p) + # Compute product + p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm) + r.set(entry, term, p) return r, state @@ -248,7 +252,7 @@ class Exp(Expr, ReducibleExpr): raise InvalidShape(f"Cannot add shapes {self.left.shape} and {self.right.shape}.") return self.left.shape - def reduce(self) -> Expr: + def reduce(self) -> ExprWithRepr: """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ var = self.left ntimes = self.right.value - 1 @@ -268,9 +272,9 @@ class PolyRingAlgebra(AlgebraicStructure, Protocol): This is the algebra of :math:`\mathbb{R}[x_1, \ldots, x_n]`. Note that the polynomials are scalars. """ - _algebra = AlgebraicStructure.Algebra.poly_ring - _parameter = Param - _constant = Const + _algebra: AlgebraicStructure.Algebra = AlgebraicStructure.Algebra.poly_ring + _parameter: type = Param + _constant: type = Const def __add__(self, other): other = self._wrap_if_constant(other) @@ -333,7 +337,7 @@ class PolyMul(Mul, PolyRingAlgebra): class PolyDiv(Expr, ReducibleExpr, PolyRingAlgebra): """ Division of scalar polynomial by scalar. """ - def reduce(self) -> Expr: + def reduce(self) -> ExprWithRepr: """ See :py:meth:`mdpoly.algebra.ReducibleExpr.reduce`. """ inverse = self._constant(value=1/self.right.value) return PolyMul(inverse, self.left) @@ -358,12 +362,12 @@ class PartialDiff(Expr, HasRepr, PolyRingAlgebra): """ See :py:meth:`mdpoly.abc.Expr.shape`. """ return self.inner.shape - def to_repr(self, repr_type: type, state: State) -> Repr: + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: """ See :py:meth:`mdpoly.representations.HasRepr.to_repr`. """ r = repr_type() lrepr, state = self.left.to_repr(repr_type, state) - entry = MatrixIndex.scalar + entry = MatrixIndex.scalar() wrt = state.index(self.wrt) for term in lrepr.terms(entry): @@ -408,10 +412,10 @@ class MatrixAlgebra(AlgebraicStructure, Protocol): already included (eg. transposition). """ - _algebra = AlgebraicStructure.Algebra.matrix_ring + _algebra: AlgebraicStructure.Algebra = AlgebraicStructure.Algebra.matrix_ring # FIXME: consider MatParam or something like that? - _parameter = Param - _constant = Const + _parameter: type = Param + _constant: type = Const def __add__(self, other): other = self._wrap_if_constant(other) @@ -504,7 +508,7 @@ class MatMul(Expr, MatrixAlgebra): @property def shape(self) -> Shape: """ See :py:meth:`mdpoly.abc.Expr.shape`. """ - ... + raise NotImplementedError def __repr__(self) -> str: return f"({self.left} @ {self.right})" diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py index 14562a9..dbfd332 100644 --- a/mdpoly/leaves.py +++ b/mdpoly/leaves.py @@ -4,11 +4,8 @@ from .state import State from .errors import MissingParameters from .representations import HasRepr -from typing import TypeVar from dataclasses import dataclass -T = TypeVar("T", bound=Repr) - @dataclass(frozen=True) class Nothing(Leaf): @@ -26,11 +23,11 @@ class Const(Leaf, HasRepr): """ value: Number name: str = "" - shape: Shape = Shape.scalar + shape: Shape = Shape.scalar() - def to_repr(self, repr_type: type[T], state: State) -> T: + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: r = repr_type() - r.set(MatrixIndex.scalar, PolyIndex.constant, self.value) + r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value) return r, state def __repr__(self) -> str: @@ -46,12 +43,12 @@ class Var(Leaf, HasRepr): Polynomial variable """ name: str - shape: Shape = Shape.scalar + shape: Shape = Shape.scalar() - def to_repr(self, repr_type: type[T], state: State) -> T: + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: r = repr_type() idx = PolyVarIndex.from_var(self, state), - r.set(MatrixIndex.scalar, PolyIndex(idx), 1) + r.set(MatrixIndex.scalar(), PolyIndex(idx), 1) return r, state def __repr__(self) -> str: @@ -64,9 +61,9 @@ class Param(Leaf, HasRepr): A parameter """ name: str - shape: Shape = Shape.scalar + shape: Shape = Shape.scalar() - def to_repr(self, repr_type: type[T], state: State) -> T: + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: if self not in state.parameters: raise MissingParameters("Cannot construct representation because " f"value for parameter {self} was not given.") diff --git a/mdpoly/representations.py b/mdpoly/representations.py index da038bd..6aa3088 100644 --- a/mdpoly/representations.py +++ b/mdpoly/representations.py @@ -2,7 +2,7 @@ from .abc import Repr from .types import Number, Shape, MatrixIndex, PolyIndex from .state import State -from typing import Protocol, Sequence +from typing import Protocol, Sequence, Iterable from abc import abstractmethod import numpy as np @@ -50,11 +50,11 @@ class SparseRepr(Repr): if not self.data[entry]: del self.data[entry] - def entries(self) -> Sequence[MatrixIndex]: + def entries(self) -> Iterable[MatrixIndex]: """ Return indices to non-zero entries of the matrix. """ yield from self.data.keys() - def terms(self, entry: MatrixIndex) -> Sequence[PolyIndex]: + def terms(self, entry: MatrixIndex) -> Iterable[PolyIndex]: """ Return indices to terms with a non-zero coefficient in the polynomial at the given matrix entry. """ if entry in self.data.keys(): @@ -97,7 +97,7 @@ class SparseMatrixRepr(Repr): class DenseRepr(Repr): """ Dense representation of a polynomial that uses Numpy arrays. """ - data: npt.NDArray[Number] + data: npt.NDArray def __init__(self, shape: Shape, dtype:type =float): self.data = np.zeros(shape, dtype) @@ -110,11 +110,11 @@ class DenseRepr(Repr): """ Set value of polynomial entry """ raise NotImplementedError - def entries(self) -> Sequence[MatrixIndex]: + def entries(self) -> Iterable[MatrixIndex]: """ Return indices to non-zero entries of the matrix """ raise NotImplementedError - def terms(self, entry: MatrixIndex) -> Sequence[PolyIndex]: + def terms(self, entry: MatrixIndex) -> Iterable[PolyIndex]: """ Return indices to non-zero terms in the polynomial at the given matrix entry """ raise NotImplementedError diff --git a/mdpoly/types.py b/mdpoly/types.py index cb0213e..46c5882 100644 --- a/mdpoly/types.py +++ b/mdpoly/types.py @@ -23,13 +23,11 @@ class Shape(NamedTuple): rows: int cols: int - @classmethod @property - def infer(cls) -> Self: - return cls(-1, -1) + def infer(self) -> Self: + return self.__class__(-1, -1) @classmethod - @property def scalar(cls) -> Self: return cls(1, 1) @@ -52,12 +50,10 @@ class MatrixIndex(NamedTuple): col: int @classmethod - @property def infer(cls, row=-1, col=-1) -> Self: return cls(row, col) @classmethod - @property def scalar(cls): """ Shorthand for index of a scalar """ return cls(row=0, col=0) @@ -79,7 +75,10 @@ class PolyVarIndex(NamedTuple): """ Make an index from a variable object. """ return cls(var_idx=state.index(variable), power=power) - def __eq__(self, other: Self): + def __eq__(self, other): + if type(other) is not PolyVarIndex: + other = PolyVarIndex(other) + if self.var_idx != other.var_idx: return False @@ -88,12 +87,14 @@ class PolyVarIndex(NamedTuple): return True - def __lt__(self, other: Self): + def __lt__(self, other): + if type(other) is not PolyVarIndex: + other = PolyVarIndex(other) + return self.var_idx < other.var_idx @classmethod - @property - def constant(cls): + def constant(cls) -> Self: """ Special index for constants. Constants do not have an associated variable, and the field power is @@ -102,7 +103,7 @@ class PolyVarIndex(NamedTuple): return cls(var_idx=-1, power=0) @staticmethod - def is_constant(index: Self) -> bool: + def is_constant(index: PolyVarIndex) -> bool: """ Check if is index of a constant term. """ return index.var_idx == -1 @@ -141,13 +142,12 @@ class PolyIndex(tuple[PolyVarIndex]): return cls(PolyVarIndex(k, v) for k, v in d.items()) @classmethod - @property def constant(cls) -> Self: """ Index of the constant term. """ - return cls((PolyVarIndex.constant,)) + return cls((PolyVarIndex.constant(),)) @staticmethod - def is_constant(index: Self) -> bool: + def is_constant(index: PolyIndex) -> bool: """ Check if is index of a constant term. """ if len(index) != 1: return False @@ -155,7 +155,7 @@ class PolyIndex(tuple[PolyVarIndex]): return PolyVarIndex.is_constant(index[0]) @classmethod - def sort(cls, index: Self) -> Self: + def sort(cls, index: tuple | Self) -> Self: """ Sort a tuple of indices. """ return cls(sorted(index)) @@ -203,7 +203,7 @@ class PolyIndex(tuple[PolyVarIndex]): # Check if is linear term if isclose(with_wrt_var.power, 1.): - return cls.sort(tuple(others) + cls.constant) + return cls.sort(tuple(others) + cls.constant()) # Decrease exponent new_idx = PolyVarIndex(var_idx=wrt, power=(with_wrt_var.power - 1)) |