aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-15 18:09:02 +0100
committerNao Pross <np@0hm.ch>2024-03-15 18:22:23 +0100
commit6b0db138a5cfb1d6c61ab60002873673aab811e1 (patch)
treed1670d843eebf26d56e69fbadc0f6db1c6b947cf
parentFix reference docs (diff)
downloadmdpoly-6b0db138a5cfb1d6c61ab60002873673aab811e1.tar.gz
mdpoly-6b0db138a5cfb1d6c61ab60002873673aab811e1.zip
Make expression objects dataclasses
-rw-r--r--mdpoly/algebra.py73
1 files changed, 29 insertions, 44 deletions
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index c51113e..4c66765 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -10,60 +10,36 @@ from typing import cast, Sequence, Iterable, Type, TypeVar, Self
from functools import reduce
from itertools import product, chain, combinations_with_replacement
from abc import abstractmethod
-from dataclassabc import dataclassabc
-
+from dataclasses import field, dataclass
import operator
+from dataclassabc import dataclassabc
-class BinaryOp(Expr):
- """ Binary Operator. TODO: desc """
- def __init__(self, left, right):
- self._left = left
- self._right = right
-
- @property
- def left(self) -> Expr:
- return self._left
-
- @left.setter
- def left(self, left: Expr) -> None:
- self._left = left
-
- @property
- def right(self) -> Expr:
- return self._right
- @right.setter
- def right(self, right: Expr) -> None:
- self._right = right
+@dataclassabc(eq=False)
+class BinaryOp(Expr):
+ left: Expr = field()
+ right: Expr = field()
+@dataclassabc(eq=False)
class UnaryOp(Expr):
- """ Unary Operator. TODO: desc. """
- def __init__(self, left, right=None):
- self._inner = left
+ right: Expr = field()
@property
def inner(self) -> Expr:
- return self._inner
-
- @property
- def left(self) -> Expr:
- return self._inner
-
- @left.setter
- def left(self, left: Expr) -> None:
- self._inner = left
+ """ Inner expression on which the operator is acting, alias for right. """
+ return self.right
@property
- def right(self) -> Expr:
+ def left(self) -> Expr:
return Nothing()
- @right.setter
- def right(self, right) -> None:
- if not isinstance(right, Nothing):
- raise ValueError("Cannot set right of left-acting unary opertator "
- "to somethig that is not of type Nothing.")
+ @left.setter
+ def left(self, left) -> None:
+ if not isinstance(left, Nothing):
+ raise ValueError("Cannot set left of left-acting unary operator "
+ "to something that is not of type Nothing.")
class Reducible(Expr):
@@ -250,6 +226,7 @@ class PolyParam(Param, PolyRingExpr):
return PolyConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg]
+@dataclass(eq=False)
class PolyAdd(BinaryOp, PolyRingExpr):
""" Addition operator between scalar polynomials. """
@@ -274,6 +251,7 @@ class PolyAdd(BinaryOp, PolyRingExpr):
return f"({self.left} + {self.right})"
+@dataclass(eq=False)
class PolySub(BinaryOp, PolyRingExpr, Reducible):
""" Subtraction operator between scalar polynomials. """
@@ -285,6 +263,7 @@ class PolySub(BinaryOp, PolyRingExpr, Reducible):
return f"({self.left} - {self.right})"
+@dataclass(eq=False)
class PolyMul(BinaryOp, PolyRingExpr):
""" Multiplication operator between scalar polynomials. """
@@ -314,6 +293,7 @@ class PolyMul(BinaryOp, PolyRingExpr):
return f"({self.left} * {self.right})"
+@dataclass(eq=False)
class PolyExp(BinaryOp, PolyRingExpr, Reducible):
""" Exponentiation operator between scalar polynomials. """
@@ -344,12 +324,10 @@ class PolyExp(BinaryOp, PolyRingExpr, Reducible):
return f"({self.left} ** {self.right})"
+@dataclassabc(eq=False)
class PolyPartialDiff(UnaryOp, PolyRingExpr):
""" Partial differentiation of scalar polynomials. """
-
- def __init__(self, inner: Expr, with_respect_to: PolyVar):
- UnaryOp.__init__(self, inner)
- self.wrt = with_respect_to
+ wrt: PolyVar
@property
def shape(self) -> Shape:
@@ -558,6 +536,7 @@ class MatParam(Param, MatrixExpr):
return self.name
+@dataclassabc
class MatAdd(BinaryOp, MatrixExpr):
""" Addition operator between matrices. """
@@ -591,6 +570,7 @@ class MatAdd(BinaryOp, MatrixExpr):
return f"({self.left} + {self.right})"
+@dataclassabc
class MatSub(BinaryOp, MatrixExpr, Reducible):
""" Subtraction operator between matrices. """
@@ -610,6 +590,7 @@ class MatSub(BinaryOp, MatrixExpr, Reducible):
return f"({self.left} - {self.right})"
+@dataclassabc
class MatElemMul(BinaryOp, MatrixExpr):
""" Elementwise Matrix Multiplication. """
@@ -648,6 +629,7 @@ class MatElemMul(BinaryOp, MatrixExpr):
return f"({self.left} .* {self.right})"
+@dataclassabc
class MatScalarMul(BinaryOp, MatrixExpr):
""" Matrix-Scalar Multiplication. Assumes scalar is on the left and matrix
on the right. """
@@ -684,6 +666,7 @@ class MatScalarMul(BinaryOp, MatrixExpr):
return f"({self.left} * {self.right})"
+@dataclass(eq=False)
class MatDotProd(BinaryOp, MatrixExpr):
""" Dot product. """
@@ -701,6 +684,7 @@ class MatDotProd(BinaryOp, MatrixExpr):
return Shape.scalar()
+@dataclass(eq=False)
class MatMul(BinaryOp, MatrixExpr):
""" Matrix Multiplication. """
@@ -718,6 +702,7 @@ class MatMul(BinaryOp, MatrixExpr):
return f"({self.left} @ {self.right})"
+@dataclass(eq=False)
class MatTranspose(UnaryOp, MatrixExpr):
""" Matrix transposition """