From 5aa22f11a68c8ae5d0a0584b997fc3c79bc712f0 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Wed, 20 Mar 2024 11:07:49 +0100 Subject: Fix MatSub, separate leaves --- mdpoly/__init__.py | 6 +-- mdpoly/errors.py | 1 + mdpoly/expressions.py | 126 ++-------------------------------------------- mdpoly/leaves.py | 128 +++++++++++++++++++++++++++++++++++++++++++++++ mdpoly/operations/add.py | 4 +- 5 files changed, 140 insertions(+), 125 deletions(-) create mode 100644 mdpoly/leaves.py diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py index afa3e27..2fbdb88 100644 --- a/mdpoly/__init__.py +++ b/mdpoly/__init__.py @@ -107,9 +107,9 @@ TODO: data structure(s) that represent the polynomial from .index import (Shape as _Shape) -from .expressions import (WithOps as _WithOps, - PolyConst as _PolyConst, PolyVar as _PolyVar, PolyParam as _PolyParam, - MatConst as _MatConst, MatVar as _MatVar, MatParam as _MatParam) +from .expressions import (WithOps as _WithOps) +from .leaves import (PolyConst as _PolyConst, PolyVar as _PolyVar, PolyParam as _PolyParam, + MatConst as _MatConst, MatVar as _MatVar, MatParam as _MatParam) from .state import State as _State diff --git a/mdpoly/errors.py b/mdpoly/errors.py index 262f27e..18e5616 100644 --- a/mdpoly/errors.py +++ b/mdpoly/errors.py @@ -7,5 +7,6 @@ class InvalidShape(Exception): """ This is raised whenever an operation cannot be perfomed because the shapes of the inputs do not match. """ + class MissingParameters(Exception): """ This is raised whenever a parameter was not given. """ diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py index 6765b93..2824332 100644 --- a/mdpoly/expressions.py +++ b/mdpoly/expressions.py @@ -1,14 +1,14 @@ from __future__ import annotations from typing import TYPE_CHECKING -from dataclassabc import dataclassabc from dataclasses import dataclass from functools import wraps -from typing import Type, TypeVar, Iterable, Callable, Sequence, Any, Self, cast +from typing import Type, Callable, Any, Self -from .abc import Expr, Var, Const, Param, Repr -from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number -from .errors import MissingParameters, AlgebraicError +from .abc import Expr, Var, Repr +from .index import Number +from .leaves import PolyConst +from .errors import AlgebraicError from .representations import SparseRepr from .operations import WithTraceback @@ -22,122 +22,6 @@ if TYPE_CHECKING: from .state import State -# ┏┳┓┏━┓╺┳╸┏━┓╻┏━╸┏━╸┏━┓ -# ┃┃┃┣━┫ ┃ ┣┳┛┃┃ ┣╸ ┗━┓ -# ╹ ╹╹ ╹ ╹ ╹┗╸╹┗━╸┗━╸┗━┛ - - -@dataclassabc(frozen=True) -class MatConst(Const): - """ Matrix constant. TODO: desc. """ - value: Sequence[Sequence[Number]] # Row major, overloads Const.value - shape: Shape # overloads Expr.shape - name: str = "" # overloads Leaf.name - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - r = repr_type(self.shape) - - for i, row in enumerate(self.value): - for j, val in enumerate(row): - r.set(MatrixIndex(row=i, col=j), PolyIndex.constant(), val) - - return r, state - - - def __str__(self) -> str: - if not self.name: - return repr(self.value) - - return self.name - - -T = TypeVar("T", bound=Var) - -@dataclassabc(frozen=True) -class MatVar(Var): - """ Matrix of polynomial variables. TODO: desc """ - name: str # overloads Leaf.name - shape: Shape # overloads Expr.shape - - # TODO: review this API, can be moved elsewhere? - def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]: - for row in range(self.shape.rows): - for col in range(self.shape.cols): - var = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg] - entry = MatrixIndex(row, col) - - yield entry, var - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - r = repr_type(self.shape) - - # FIXME: do not hardcode scalar type - for entry, var in self.to_scalars(PolyVar): - idx = PolyVarIndex.from_var(var, state), # important comma! - r.set(entry, PolyIndex(idx), 1) - - return r, state - - def __str__(self) -> str: - return self.name - - -@dataclassabc(frozen=True) -class MatParam(Param): - """ Matrix parameter. TODO: desc. """ - name: str - shape: Shape - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - p = state.parameter(self) - return MatConst(p).to_repr(repr_type, state) # type: ignore[call-arg] - - def __str__(self) -> str: - return self.name - - -# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓ -# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┗━┓ -# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸┗━┛ - - -@dataclassabc(frozen=True) -class PolyVar(Var): - """ Variable TODO: desc """ - name: str # overloads Leaf.name - shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - r = repr_type(self.shape) - idx = PolyVarIndex.from_var(self, state), # important comma! - r.set(MatrixIndex.scalar(), PolyIndex(idx), 1) - return r, state - - -@dataclassabc(frozen=True) -class PolyConst(Const): - """ Constant TODO: desc """ - value: Number # overloads Const.value - name: str = "" # overloads Leaf.name - shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - r = repr_type(self.shape) - r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value) - return r, state - - -@dataclassabc(frozen=True) -class PolyParam(Param): - """ Polynomial parameter TODO: desc """ - name: str # overloads Leaf.name - shape: Shape = Shape.scalar() # overloads PolyExpr.shape - - def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: - p = state.parameter(self) - return PolyConst(p).to_repr(repr_type, state) # type: ignore[call-arg] - - # ┏━┓┏━┓┏━╸┏━┓┏━┓╺┳╸╻┏━┓┏┓╻┏━┓ # ┃ ┃┣━┛┣╸ ┣┳┛┣━┫ ┃ ┃┃ ┃┃┗┫┗━┓ # ┗━┛╹ ┗━╸╹┗╸╹ ╹ ╹ ╹┗━┛╹ ╹┗━┛ diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py new file mode 100644 index 0000000..bdc9c67 --- /dev/null +++ b/mdpoly/leaves.py @@ -0,0 +1,128 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from dataclassabc import dataclassabc +from typing import Type, TypeVar, Iterable, Sequence + +from .abc import Var, Const, Param +from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number + +if TYPE_CHECKING: + from .abc import ReprT + from .state import State + + +# ┏┳┓┏━┓╺┳╸┏━┓╻┏━╸┏━╸┏━┓ +# ┃┃┃┣━┫ ┃ ┣┳┛┃┃ ┣╸ ┗━┓ +# ╹ ╹╹ ╹ ╹ ╹┗╸╹┗━╸┗━╸┗━┛ + + +@dataclassabc(frozen=True) +class MatConst(Const): + """ Matrix constant. TODO: desc. """ + value: Sequence[Sequence[Number]] # Row major, overloads Const.value + shape: Shape # overloads Expr.shape + name: str = "" # overloads Leaf.name + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + r = repr_type(self.shape) + + for i, row in enumerate(self.value): + for j, val in enumerate(row): + r.set(MatrixIndex(row=i, col=j), PolyIndex.constant(), val) + + return r, state + + + def __str__(self) -> str: + if not self.name: + return repr(self.value) + + return self.name + + +T = TypeVar("T", bound=Var) + +@dataclassabc(frozen=True) +class MatVar(Var): + """ Matrix of polynomial variables. TODO: desc """ + name: str # overloads Leaf.name + shape: Shape # overloads Expr.shape + + # TODO: review this API, can be moved elsewhere? + def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]: + for row in range(self.shape.rows): + for col in range(self.shape.cols): + var = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg] + entry = MatrixIndex(row, col) + + yield entry, var + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + r = repr_type(self.shape) + + # FIXME: do not hardcode scalar type + for entry, var in self.to_scalars(PolyVar): + idx = PolyVarIndex.from_var(var, state), # important comma! + r.set(entry, PolyIndex(idx), 1) + + return r, state + + def __str__(self) -> str: + return self.name + + +@dataclassabc(frozen=True) +class MatParam(Param): + """ Matrix parameter. TODO: desc. """ + name: str + shape: Shape + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + p = state.parameter(self) + return MatConst(p).to_repr(repr_type, state) # type: ignore[call-arg] + + def __str__(self) -> str: + return self.name + + +# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓ +# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┗━┓ +# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸┗━┛ + + +@dataclassabc(frozen=True) +class PolyVar(Var): + """ Variable TODO: desc """ + name: str # overloads Leaf.name + shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + r = repr_type(self.shape) + idx = PolyVarIndex.from_var(self, state), # important comma! + r.set(MatrixIndex.scalar(), PolyIndex(idx), 1) + return r, state + + +@dataclassabc(frozen=True) +class PolyConst(Const): + """ Constant TODO: desc """ + value: Number # overloads Const.value + name: str = "" # overloads Leaf.name + shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + r = repr_type(self.shape) + r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value) + return r, state + + +@dataclassabc(frozen=True) +class PolyParam(Param): + """ Polynomial parameter TODO: desc """ + name: str # overloads Leaf.name + shape: Shape = Shape.scalar() # overloads PolyExpr.shape + + def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]: + p = state.parameter(self) + return PolyConst(p).to_repr(repr_type, state) # type: ignore[call-arg] diff --git a/mdpoly/operations/add.py b/mdpoly/operations/add.py index 30ad3b4..2871a47 100644 --- a/mdpoly/operations/add.py +++ b/mdpoly/operations/add.py @@ -5,9 +5,11 @@ from typing import Type from dataclassabc import dataclassabc from ..abc import Expr +from ..leaves import PolyConst from ..errors import AlgebraicError from . import BinaryOp, Reducible +from .mul import MatScalarMul if TYPE_CHECKING: from ..abc import ReprT @@ -63,7 +65,7 @@ class MatSub(BinaryOp, Reducible): def reduce(self) -> Expr: """ See :py:meth:`mdpoly.expressions.Reducible.reduce`. """ - return self.left + (-1 * self.right) + return MatAdd(self.left, MatScalarMul(PolyConst(-1), self.right)) def __str__(self) -> str: return f"({self.left} - {self.right})" -- cgit v1.2.1