aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-07 12:32:44 +0100
committerNao Pross <np@0hm.ch>2024-03-07 12:32:44 +0100
commit99e36685b023fc55b95678009a0d6d872f25e3aa (patch)
treecb3ed6e249fdfac07a52b702034ff0bbf1a1ff88
parentAdd missing docstrings (diff)
downloadmdpoly-99e36685b023fc55b95678009a0d6d872f25e3aa.tar.gz
mdpoly-99e36685b023fc55b95678009a0d6d872f25e3aa.zip
Fix regression of algebra check
Feature broke on commit cdfac0a6a8a9e4225a03aeb6e7115f6fc2f2fbad while switching from structural typing to OOP inheritance
-rw-r--r--mdpoly/__init__.py14
-rw-r--r--mdpoly/abc.py20
-rw-r--r--mdpoly/algebra.py27
-rw-r--r--mdpoly/leaves.py79
-rw-r--r--poetry.lock13
-rw-r--r--pyproject.toml1
6 files changed, 131 insertions, 23 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py
index fddd62b..0c1576d 100644
--- a/mdpoly/__init__.py
+++ b/mdpoly/__init__.py
@@ -3,14 +3,14 @@
# internal classes imported with underscore because
# they should not be exposed to the end users
-from .abc import (Shape as _Shape)
+from .abc import (Algebra as _Algebra, Shape as _Shape)
from .algebra import (PolyRingExpr as _PolyRingExpr,
MatrixExpr as _MatrixExpr)
-from .leaves import (Const as _Const,
- Var as _Var,
- Param as _Param)
+from .leaves import (Const as _Const, MatConst as _MatConst,
+ Var as _Var, MatVar as _MatVar,
+ Param as _Param, MatParam as _MatParam)
from .state import State as _State
@@ -51,15 +51,15 @@ class Parameter(_Param, _PolyRingExpr):
""" Parameter that can be substituted """
-class MatrixConstant(_Const, _PolyRingExpr):
+class MatrixConstant(_MatConst, _PolyRingExpr):
""" Matrix constant """
-class MatrixVariable(_Var, _MatrixExpr):
+class MatrixVariable(_MatVar, _MatrixExpr):
""" Matrix Polynomial Variable """
-class MatrixParameter(_Param, _MatrixExpr):
+class MatrixParameter(_MatParam, _MatrixExpr):
""" Matrix Parameter """
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index 488d89f..99d9484 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -3,6 +3,7 @@ from __future__ import annotations
from .index import Number, Shape, MatrixIndex, PolyIndex
from .constants import NUMERICS_EPS
+from .errors import AlgebraicError
from .state import State
from .util import iszero
@@ -153,23 +154,34 @@ class Expr(ABC):
# --- Operator overloading ---
@staticmethod
- def _wrap(if_type: type, wrapper_type: type, obj: Any) -> Expr:
+ def _assert_same_algebra(left: Expr, right: Expr) -> None:
+ if not isinstance(left, Expr) or not isinstance(right, Expr):
+ return
+
+ if left.algebra != right.algebra:
+ raise AlgebraicError("Cannot perform algebraic operation between "
+ f"{left} and {right} because they have different "
+ f"algebras {left.algebra} and {right.algebra}.")
+
+ @staticmethod
+ def _wrap(if_type: type, wrapper_type: type, obj: Any, *args, **kwargs) -> Expr:
""" Wrap non-expr objects.
Suppose ``x`` is of type *Expr*, then we would like to be able to do
things like ``x + 1``, so this function can be called in operator
overloadings to for instance wrapp the 1 into a
:py:class:`mdpoly.leaves.Const`. If ``obj`` is already of type *Expr*,
- this function does nothing.
+ this function does nothing. The arguments *args* and *kwargs* are
+ forwarded to the constructor of *wrapper_type*.
"""
# Do not wrap if is alreay an expression
if isinstance(obj, Expr):
return obj
if not isinstance(obj, if_type):
- raise TypeError
+ raise TypeError(f"Cannot wrap {obj} with type {wrapper_type} because it is not of type {if_type}.")
- return wrapper_type(obj)
+ return wrapper_type(obj, *args, **kwargs)
def __add__(self, other: Any) -> Self:
raise NotImplementedError
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index 1efbabf..580c2e2 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -2,7 +2,7 @@
from __future__ import annotations
from .abc import Algebra, Expr, Repr
-from .leaves import Nothing, Const, Param, Var
+from .leaves import Nothing, Const, Param, Var, MatConst
from .errors import AlgebraicError, InvalidShape
from .state import State
from .index import Shape, MatrixIndex, PolyIndex, Number
@@ -94,8 +94,8 @@ class PolyRingExpr(Expr):
"""
@property
- def algebra(self):
- return
+ def algebra(self) -> Algebra:
+ return Algebra.poly_ring
@property
def shape(self) -> Shape:
@@ -106,18 +106,22 @@ class PolyRingExpr(Expr):
return self.left.shape
def __add__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return PolyAdd(self, other)
def __radd__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return PolyAdd(other, self)
def __sub__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return PolySub(self, other)
def __rsub__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return PolyAdd(other, self)
@@ -125,10 +129,12 @@ class PolyRingExpr(Expr):
return PolyMul(self._constant(-1), self)
def __mul__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return PolyMul(self, other)
def __rmul__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return PolyMul(other, self)
@@ -140,7 +146,7 @@ class PolyRingExpr(Expr):
def __truediv__(self, other):
other = self._wrap(Number, Const, other)
- if not self._is_const_or_param(other):
+ if not isinstance(other, Const | Param):
raise AlgebraicError("Cannot divide by variables in polynomial ring.")
return PolyMul(self, other)
@@ -313,34 +319,41 @@ class MatrixExpr(Expr):
"""
@property
- def algebra(self):
- return Algebra.matrix_algebra
+ def algebra(self) -> Algebra:
+ return Algebra.matrix_ring
def __add__(self, other):
- other = self._wrap(Number, Const, other)
+ self._assert_same_algebra(self, other)
+ other = self._wrap(Number, MatConst, other)
return MatAdd(self, other)
def __sub__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return MatSub(self, other)
def __rsub__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return MatSub(other, self)
def __mul__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return MatScalarMul(other, self)
def __rmul__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return MatScalarMul(other, self)
def __matmul__(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return MatMul(self, other)
def __rmatmul(self, other):
+ self._assert_same_algebra(self, other)
other = self._wrap(Number, Const, other)
return MatMul(other, self)
diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py
index 90c303e..fe2be49 100644
--- a/mdpoly/leaves.py
+++ b/mdpoly/leaves.py
@@ -4,6 +4,8 @@ from .state import State
from .errors import MissingParameters
from dataclasses import dataclass
+from dataclassabc import dataclassabc
+from typing import Iterable
@dataclass(frozen=True)
@@ -22,6 +24,11 @@ class Nothing(Leaf):
return "Nothing"
+# ┏━┓┏━╸┏━┓╻ ┏━┓┏━┓ ╻ ┏━╸┏━┓╻ ╻┏━╸┏━┓
+# ┗━┓┃ ┣━┫┃ ┣━┫┣┳┛ ┃ ┣╸ ┣━┫┃┏┛┣╸ ┗━┓
+# ┗━┛┗━╸╹ ╹┗━╸╹ ╹╹┗╸ ┗━╸┗━╸╹ ╹┗┛ ┗━╸┗━┛
+
+
@dataclass(frozen=True)
class Const(Leaf):
"""
@@ -30,7 +37,7 @@ class Const(Leaf):
value: Number
name: str = ""
shape: Shape = Shape.scalar()
- algebra: Algebra = Algebra.none
+ algebra: Algebra = Algebra.poly_ring
def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
r = repr_type()
@@ -39,7 +46,7 @@ class Const(Leaf):
def __repr__(self) -> str:
if not self.name:
- return str(self.value)
+ return repr(self.value)
return self.name
@@ -51,7 +58,7 @@ class Var(Leaf):
"""
name: str
shape: Shape = Shape.scalar()
- algebra: Algebra = Algebra.none
+ algebra: Algebra = Algebra.poly_ring
def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
r = repr_type()
@@ -70,7 +77,7 @@ class Param(Leaf):
"""
name: str
shape: Shape = Shape.scalar()
- algebra: Algebra = Algebra.none
+ algebra: Algebra = Algebra.poly_ring
def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
if self not in state.parameters:
@@ -81,3 +88,67 @@ class Param(Leaf):
def __repr__(self) -> str:
return self.name
+
+
+# ┏┳┓┏━┓╺┳╸┏━┓╻╻ ╻ ╻ ┏━╸┏━┓╻ ╻┏━╸┏━┓
+# ┃┃┃┣━┫ ┃ ┣┳┛┃┏╋┛ ┃ ┣╸ ┣━┫┃┏┛┣╸ ┗━┓
+# ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ┗━╸┗━╸╹ ╹┗┛ ┗━╸┗━┛
+
+
+@dataclassabc(frozen=True)
+class MatConst(Leaf):
+ """
+ A matrix constant
+ """
+ value: Iterable[Iterable[Number]]
+ shape: Shape
+ name: str = ""
+ algebra: Algebra = Algebra.matrix_ring
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ raise UserWarning("MatConst.to_repr is not implemented!")
+ return repr_type(), state
+
+ def __repr__(self) -> str:
+ if not self.name:
+ return repr(self.value)
+
+ return self.name
+
+
+@dataclassabc(frozen=True)
+class MatVar(Leaf):
+ """
+ Matrix polynomial variable
+ """
+ name: str
+ shape: Shape
+ algebra: Algebra = Algebra.matrix_ring
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ raise UserWarning("MatVar.to_repr is not implemented!")
+ return repr_type(), state
+
+ def __repr__(self) -> str:
+ return self.name
+
+
+@dataclassabc(frozen=True)
+class MatParam(Leaf):
+ """
+ Matrix parameter
+ """
+ name: str
+ shape: Shape
+ algebra: Algebra = Algebra.poly_ring
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ if self not in state.parameters:
+ raise MissingParameters("Cannot construct representation because "
+ f"value for parameter {self} was not given.")
+
+ # FIXME: add conversion to scalar variables
+ return MatConst(state.parameters[self]).to_repr(repr_type, state)
+
+ def __repr__(self) -> str:
+ return self.name
diff --git a/poetry.lock b/poetry.lock
index b519924..b3409d2 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -427,6 +427,17 @@ traitlets = ">=4"
test = ["pytest"]
[[package]]
+name = "dataclass-abc"
+version = "0.0.8"
+description = "Library that lets you define abstract properties for dataclasses."
+optional = false
+python-versions = ">=3.10"
+files = [
+ {file = "dataclass-abc-0.0.8.tar.gz", hash = "sha256:f93f93c5e30d982af39539bbc6b83fda0a8158c9fd30e2b2e016992825073f29"},
+ {file = "dataclass_abc-0.0.8-py3-none-any.whl", hash = "sha256:817ebd5b83e9853129061faca432d31a3756580ab9d3df18405cda63769462ec"},
+]
+
+[[package]]
name = "debugpy"
version = "1.8.1"
description = "An implementation of the Debug Adapter Protocol for Python"
@@ -2364,4 +2375,4 @@ test = ["websockets"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0"
-content-hash = "c69dc3809703fd94763f142c9f748e8a2bcb82b34db98742591f072de73c2902"
+content-hash = "304dd7d25bb710afad9207207882b98e34dd3cff3a7599358864a66ebf23bdd6"
diff --git a/pyproject.toml b/pyproject.toml
index d204db0..7084516 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,7 @@ readme = "README.md"
python = ">=3.10,<4.0"
numpy = "^1.26.4"
scipy = "^1.12.0"
+dataclass-abc = "^0.0.8"
[tool.poetry.group.notebook.dependencies]