diff options
author | Nao Pross <np@0hm.ch> | 2024-03-07 12:32:44 +0100 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-03-07 12:32:44 +0100 |
commit | 99e36685b023fc55b95678009a0d6d872f25e3aa (patch) | |
tree | cb3ed6e249fdfac07a52b702034ff0bbf1a1ff88 | |
parent | Add missing docstrings (diff) | |
download | mdpoly-99e36685b023fc55b95678009a0d6d872f25e3aa.tar.gz mdpoly-99e36685b023fc55b95678009a0d6d872f25e3aa.zip |
Fix regression of algebra check
Feature broke on commit cdfac0a6a8a9e4225a03aeb6e7115f6fc2f2fbad while
switching from structural typing to OOP inheritance
Diffstat (limited to '')
-rw-r--r-- | mdpoly/__init__.py | 14 | ||||
-rw-r--r-- | mdpoly/abc.py | 20 | ||||
-rw-r--r-- | mdpoly/algebra.py | 27 | ||||
-rw-r--r-- | mdpoly/leaves.py | 79 | ||||
-rw-r--r-- | poetry.lock | 13 | ||||
-rw-r--r-- | pyproject.toml | 1 |
6 files changed, 131 insertions, 23 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py index fddd62b..0c1576d 100644 --- a/mdpoly/__init__.py +++ b/mdpoly/__init__.py @@ -3,14 +3,14 @@ # internal classes imported with underscore because # they should not be exposed to the end users -from .abc import (Shape as _Shape) +from .abc import (Algebra as _Algebra, Shape as _Shape) from .algebra import (PolyRingExpr as _PolyRingExpr, MatrixExpr as _MatrixExpr) -from .leaves import (Const as _Const, - Var as _Var, - Param as _Param) +from .leaves import (Const as _Const, MatConst as _MatConst, + Var as _Var, MatVar as _MatVar, + Param as _Param, MatParam as _MatParam) from .state import State as _State @@ -51,15 +51,15 @@ class Parameter(_Param, _PolyRingExpr): """ Parameter that can be substituted """ -class MatrixConstant(_Const, _PolyRingExpr): +class MatrixConstant(_MatConst, _PolyRingExpr): """ Matrix constant """ -class MatrixVariable(_Var, _MatrixExpr): +class MatrixVariable(_MatVar, _MatrixExpr): """ Matrix Polynomial Variable """ -class MatrixParameter(_Param, _MatrixExpr): +class MatrixParameter(_MatParam, _MatrixExpr): """ Matrix Parameter """ diff --git a/mdpoly/abc.py b/mdpoly/abc.py index 488d89f..99d9484 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -3,6 +3,7 @@ from __future__ import annotations from .index import Number, Shape, MatrixIndex, PolyIndex from .constants import NUMERICS_EPS +from .errors import AlgebraicError from .state import State from .util import iszero @@ -153,23 +154,34 @@ class Expr(ABC): # --- Operator overloading --- @staticmethod - def _wrap(if_type: type, wrapper_type: type, obj: Any) -> Expr: + def _assert_same_algebra(left: Expr, right: Expr) -> None: + if not isinstance(left, Expr) or not isinstance(right, Expr): + return + + if left.algebra != right.algebra: + raise AlgebraicError("Cannot perform algebraic operation between " + f"{left} and {right} because they have different " + f"algebras {left.algebra} and {right.algebra}.") + + @staticmethod + def _wrap(if_type: type, wrapper_type: type, obj: Any, *args, **kwargs) -> Expr: """ Wrap non-expr objects. Suppose ``x`` is of type *Expr*, then we would like to be able to do things like ``x + 1``, so this function can be called in operator overloadings to for instance wrapp the 1 into a :py:class:`mdpoly.leaves.Const`. If ``obj`` is already of type *Expr*, - this function does nothing. + this function does nothing. The arguments *args* and *kwargs* are + forwarded to the constructor of *wrapper_type*. """ # Do not wrap if is alreay an expression if isinstance(obj, Expr): return obj if not isinstance(obj, if_type): - raise TypeError + raise TypeError(f"Cannot wrap {obj} with type {wrapper_type} because it is not of type {if_type}.") - return wrapper_type(obj) + return wrapper_type(obj, *args, **kwargs) def __add__(self, other: Any) -> Self: raise NotImplementedError diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index 1efbabf..580c2e2 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -2,7 +2,7 @@ from __future__ import annotations from .abc import Algebra, Expr, Repr -from .leaves import Nothing, Const, Param, Var +from .leaves import Nothing, Const, Param, Var, MatConst from .errors import AlgebraicError, InvalidShape from .state import State from .index import Shape, MatrixIndex, PolyIndex, Number @@ -94,8 +94,8 @@ class PolyRingExpr(Expr): """ @property - def algebra(self): - return + def algebra(self) -> Algebra: + return Algebra.poly_ring @property def shape(self) -> Shape: @@ -106,18 +106,22 @@ class PolyRingExpr(Expr): return self.left.shape def __add__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return PolyAdd(self, other) def __radd__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return PolyAdd(other, self) def __sub__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return PolySub(self, other) def __rsub__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return PolyAdd(other, self) @@ -125,10 +129,12 @@ class PolyRingExpr(Expr): return PolyMul(self._constant(-1), self) def __mul__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return PolyMul(self, other) def __rmul__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return PolyMul(other, self) @@ -140,7 +146,7 @@ class PolyRingExpr(Expr): def __truediv__(self, other): other = self._wrap(Number, Const, other) - if not self._is_const_or_param(other): + if not isinstance(other, Const | Param): raise AlgebraicError("Cannot divide by variables in polynomial ring.") return PolyMul(self, other) @@ -313,34 +319,41 @@ class MatrixExpr(Expr): """ @property - def algebra(self): - return Algebra.matrix_algebra + def algebra(self) -> Algebra: + return Algebra.matrix_ring def __add__(self, other): - other = self._wrap(Number, Const, other) + self._assert_same_algebra(self, other) + other = self._wrap(Number, MatConst, other) return MatAdd(self, other) def __sub__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return MatSub(self, other) def __rsub__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return MatSub(other, self) def __mul__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return MatScalarMul(other, self) def __rmul__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return MatScalarMul(other, self) def __matmul__(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return MatMul(self, other) def __rmatmul(self, other): + self._assert_same_algebra(self, other) other = self._wrap(Number, Const, other) return MatMul(other, self) diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py index 90c303e..fe2be49 100644 --- a/mdpoly/leaves.py +++ b/mdpoly/leaves.py @@ -4,6 +4,8 @@ from .state import State from .errors import MissingParameters from dataclasses import dataclass +from dataclassabc import dataclassabc +from typing import Iterable @dataclass(frozen=True) @@ -22,6 +24,11 @@ class Nothing(Leaf): return "Nothing" +# ┏━┓┏━╸┏━┓╻ ┏━┓┏━┓ ╻ ┏━╸┏━┓╻ ╻┏━╸┏━┓ +# ┗━┓┃ ┣━┫┃ ┣━┫┣┳┛ ┃ ┣╸ ┣━┫┃┏┛┣╸ ┗━┓ +# ┗━┛┗━╸╹ ╹┗━╸╹ ╹╹┗╸ ┗━╸┗━╸╹ ╹┗┛ ┗━╸┗━┛ + + @dataclass(frozen=True) class Const(Leaf): """ @@ -30,7 +37,7 @@ class Const(Leaf): value: Number name: str = "" shape: Shape = Shape.scalar() - algebra: Algebra = Algebra.none + algebra: Algebra = Algebra.poly_ring def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: r = repr_type() @@ -39,7 +46,7 @@ class Const(Leaf): def __repr__(self) -> str: if not self.name: - return str(self.value) + return repr(self.value) return self.name @@ -51,7 +58,7 @@ class Var(Leaf): """ name: str shape: Shape = Shape.scalar() - algebra: Algebra = Algebra.none + algebra: Algebra = Algebra.poly_ring def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: r = repr_type() @@ -70,7 +77,7 @@ class Param(Leaf): """ name: str shape: Shape = Shape.scalar() - algebra: Algebra = Algebra.none + algebra: Algebra = Algebra.poly_ring def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: if self not in state.parameters: @@ -81,3 +88,67 @@ class Param(Leaf): def __repr__(self) -> str: return self.name + + +# ┏┳┓┏━┓╺┳╸┏━┓╻╻ ╻ ╻ ┏━╸┏━┓╻ ╻┏━╸┏━┓ +# ┃┃┃┣━┫ ┃ ┣┳┛┃┏╋┛ ┃ ┣╸ ┣━┫┃┏┛┣╸ ┗━┓ +# ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ┗━╸┗━╸╹ ╹┗┛ ┗━╸┗━┛ + + +@dataclassabc(frozen=True) +class MatConst(Leaf): + """ + A matrix constant + """ + value: Iterable[Iterable[Number]] + shape: Shape + name: str = "" + algebra: Algebra = Algebra.matrix_ring + + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + raise UserWarning("MatConst.to_repr is not implemented!") + return repr_type(), state + + def __repr__(self) -> str: + if not self.name: + return repr(self.value) + + return self.name + + +@dataclassabc(frozen=True) +class MatVar(Leaf): + """ + Matrix polynomial variable + """ + name: str + shape: Shape + algebra: Algebra = Algebra.matrix_ring + + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + raise UserWarning("MatVar.to_repr is not implemented!") + return repr_type(), state + + def __repr__(self) -> str: + return self.name + + +@dataclassabc(frozen=True) +class MatParam(Leaf): + """ + Matrix parameter + """ + name: str + shape: Shape + algebra: Algebra = Algebra.poly_ring + + def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]: + if self not in state.parameters: + raise MissingParameters("Cannot construct representation because " + f"value for parameter {self} was not given.") + + # FIXME: add conversion to scalar variables + return MatConst(state.parameters[self]).to_repr(repr_type, state) + + def __repr__(self) -> str: + return self.name diff --git a/poetry.lock b/poetry.lock index b519924..b3409d2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -427,6 +427,17 @@ traitlets = ">=4" test = ["pytest"] [[package]] +name = "dataclass-abc" +version = "0.0.8" +description = "Library that lets you define abstract properties for dataclasses." +optional = false +python-versions = ">=3.10" +files = [ + {file = "dataclass-abc-0.0.8.tar.gz", hash = "sha256:f93f93c5e30d982af39539bbc6b83fda0a8158c9fd30e2b2e016992825073f29"}, + {file = "dataclass_abc-0.0.8-py3-none-any.whl", hash = "sha256:817ebd5b83e9853129061faca432d31a3756580ab9d3df18405cda63769462ec"}, +] + +[[package]] name = "debugpy" version = "1.8.1" description = "An implementation of the Debug Adapter Protocol for Python" @@ -2364,4 +2375,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "c69dc3809703fd94763f142c9f748e8a2bcb82b34db98742591f072de73c2902" +content-hash = "304dd7d25bb710afad9207207882b98e34dd3cff3a7599358864a66ebf23bdd6" diff --git a/pyproject.toml b/pyproject.toml index d204db0..7084516 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ readme = "README.md" python = ">=3.10,<4.0" numpy = "^1.26.4" scipy = "^1.12.0" +dataclass-abc = "^0.0.8" [tool.poetry.group.notebook.dependencies] |