From 6b0db138a5cfb1d6c61ab60002873673aab811e1 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Fri, 15 Mar 2024 18:09:02 +0100 Subject: Make expression objects dataclasses --- mdpoly/algebra.py | 73 ++++++++++++++++++++++--------------------------------- 1 file changed, 29 insertions(+), 44 deletions(-) diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index c51113e..4c66765 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -10,60 +10,36 @@ from typing import cast, Sequence, Iterable, Type, TypeVar, Self from functools import reduce from itertools import product, chain, combinations_with_replacement from abc import abstractmethod -from dataclassabc import dataclassabc - +from dataclasses import field, dataclass import operator +from dataclassabc import dataclassabc -class BinaryOp(Expr): - """ Binary Operator. TODO: desc """ - def __init__(self, left, right): - self._left = left - self._right = right - - @property - def left(self) -> Expr: - return self._left - - @left.setter - def left(self, left: Expr) -> None: - self._left = left - - @property - def right(self) -> Expr: - return self._right - @right.setter - def right(self, right: Expr) -> None: - self._right = right +@dataclassabc(eq=False) +class BinaryOp(Expr): + left: Expr = field() + right: Expr = field() +@dataclassabc(eq=False) class UnaryOp(Expr): - """ Unary Operator. TODO: desc. """ - def __init__(self, left, right=None): - self._inner = left + right: Expr = field() @property def inner(self) -> Expr: - return self._inner - - @property - def left(self) -> Expr: - return self._inner - - @left.setter - def left(self, left: Expr) -> None: - self._inner = left + """ Inner expression on which the operator is acting, alias for right. """ + return self.right @property - def right(self) -> Expr: + def left(self) -> Expr: return Nothing() - @right.setter - def right(self, right) -> None: - if not isinstance(right, Nothing): - raise ValueError("Cannot set right of left-acting unary opertator " - "to somethig that is not of type Nothing.") + @left.setter + def left(self, left) -> None: + if not isinstance(left, Nothing): + raise ValueError("Cannot set left of left-acting unary operator " + "to something that is not of type Nothing.") class Reducible(Expr): @@ -250,6 +226,7 @@ class PolyParam(Param, PolyRingExpr): return PolyConst(state.parameters[self]).to_repr(repr_type, state) # type: ignore[call-arg] +@dataclass(eq=False) class PolyAdd(BinaryOp, PolyRingExpr): """ Addition operator between scalar polynomials. """ @@ -274,6 +251,7 @@ class PolyAdd(BinaryOp, PolyRingExpr): return f"({self.left} + {self.right})" +@dataclass(eq=False) class PolySub(BinaryOp, PolyRingExpr, Reducible): """ Subtraction operator between scalar polynomials. """ @@ -285,6 +263,7 @@ class PolySub(BinaryOp, PolyRingExpr, Reducible): return f"({self.left} - {self.right})" +@dataclass(eq=False) class PolyMul(BinaryOp, PolyRingExpr): """ Multiplication operator between scalar polynomials. """ @@ -314,6 +293,7 @@ class PolyMul(BinaryOp, PolyRingExpr): return f"({self.left} * {self.right})" +@dataclass(eq=False) class PolyExp(BinaryOp, PolyRingExpr, Reducible): """ Exponentiation operator between scalar polynomials. """ @@ -344,12 +324,10 @@ class PolyExp(BinaryOp, PolyRingExpr, Reducible): return f"({self.left} ** {self.right})" +@dataclassabc(eq=False) class PolyPartialDiff(UnaryOp, PolyRingExpr): """ Partial differentiation of scalar polynomials. """ - - def __init__(self, inner: Expr, with_respect_to: PolyVar): - UnaryOp.__init__(self, inner) - self.wrt = with_respect_to + wrt: PolyVar @property def shape(self) -> Shape: @@ -558,6 +536,7 @@ class MatParam(Param, MatrixExpr): return self.name +@dataclassabc class MatAdd(BinaryOp, MatrixExpr): """ Addition operator between matrices. """ @@ -591,6 +570,7 @@ class MatAdd(BinaryOp, MatrixExpr): return f"({self.left} + {self.right})" +@dataclassabc class MatSub(BinaryOp, MatrixExpr, Reducible): """ Subtraction operator between matrices. """ @@ -610,6 +590,7 @@ class MatSub(BinaryOp, MatrixExpr, Reducible): return f"({self.left} - {self.right})" +@dataclassabc class MatElemMul(BinaryOp, MatrixExpr): """ Elementwise Matrix Multiplication. """ @@ -648,6 +629,7 @@ class MatElemMul(BinaryOp, MatrixExpr): return f"({self.left} .* {self.right})" +@dataclassabc class MatScalarMul(BinaryOp, MatrixExpr): """ Matrix-Scalar Multiplication. Assumes scalar is on the left and matrix on the right. """ @@ -684,6 +666,7 @@ class MatScalarMul(BinaryOp, MatrixExpr): return f"({self.left} * {self.right})" +@dataclass(eq=False) class MatDotProd(BinaryOp, MatrixExpr): """ Dot product. """ @@ -701,6 +684,7 @@ class MatDotProd(BinaryOp, MatrixExpr): return Shape.scalar() +@dataclass(eq=False) class MatMul(BinaryOp, MatrixExpr): """ Matrix Multiplication. """ @@ -718,6 +702,7 @@ class MatMul(BinaryOp, MatrixExpr): return f"({self.left} @ {self.right})" +@dataclass(eq=False) class MatTranspose(UnaryOp, MatrixExpr): """ Matrix transposition """ -- cgit v1.2.1