aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-03-20 11:07:49 +0100
committerNao Pross <np@0hm.ch>2024-03-20 11:08:01 +0100
commit5aa22f11a68c8ae5d0a0584b997fc3c79bc712f0 (patch)
treea901cd604cbe14dc50bf6868ba9ad21b1051ef0c
parentReword exception message (diff)
downloadmdpoly-5aa22f11a68c8ae5d0a0584b997fc3c79bc712f0.tar.gz
mdpoly-5aa22f11a68c8ae5d0a0584b997fc3c79bc712f0.zip
Fix MatSub, separate leaves
Diffstat (limited to '')
-rw-r--r--mdpoly/__init__.py6
-rw-r--r--mdpoly/errors.py1
-rw-r--r--mdpoly/expressions.py126
-rw-r--r--mdpoly/leaves.py128
-rw-r--r--mdpoly/operations/add.py4
5 files changed, 140 insertions, 125 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py
index afa3e27..2fbdb88 100644
--- a/mdpoly/__init__.py
+++ b/mdpoly/__init__.py
@@ -107,9 +107,9 @@ TODO: data structure(s) that represent the polynomial
from .index import (Shape as _Shape)
-from .expressions import (WithOps as _WithOps,
- PolyConst as _PolyConst, PolyVar as _PolyVar, PolyParam as _PolyParam,
- MatConst as _MatConst, MatVar as _MatVar, MatParam as _MatParam)
+from .expressions import (WithOps as _WithOps)
+from .leaves import (PolyConst as _PolyConst, PolyVar as _PolyVar, PolyParam as _PolyParam,
+ MatConst as _MatConst, MatVar as _MatVar, MatParam as _MatParam)
from .state import State as _State
diff --git a/mdpoly/errors.py b/mdpoly/errors.py
index 262f27e..18e5616 100644
--- a/mdpoly/errors.py
+++ b/mdpoly/errors.py
@@ -7,5 +7,6 @@ class InvalidShape(Exception):
""" This is raised whenever an operation cannot be perfomed because the
shapes of the inputs do not match. """
+
class MissingParameters(Exception):
""" This is raised whenever a parameter was not given. """
diff --git a/mdpoly/expressions.py b/mdpoly/expressions.py
index 6765b93..2824332 100644
--- a/mdpoly/expressions.py
+++ b/mdpoly/expressions.py
@@ -1,14 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
-from dataclassabc import dataclassabc
from dataclasses import dataclass
from functools import wraps
-from typing import Type, TypeVar, Iterable, Callable, Sequence, Any, Self, cast
+from typing import Type, Callable, Any, Self
-from .abc import Expr, Var, Const, Param, Repr
-from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number
-from .errors import MissingParameters, AlgebraicError
+from .abc import Expr, Var, Repr
+from .index import Number
+from .leaves import PolyConst
+from .errors import AlgebraicError
from .representations import SparseRepr
from .operations import WithTraceback
@@ -22,122 +22,6 @@ if TYPE_CHECKING:
from .state import State
-# ┏┳┓┏━┓╺┳╸┏━┓╻┏━╸┏━╸┏━┓
-# ┃┃┃┣━┫ ┃ ┣┳┛┃┃ ┣╸ ┗━┓
-# ╹ ╹╹ ╹ ╹ ╹┗╸╹┗━╸┗━╸┗━┛
-
-
-@dataclassabc(frozen=True)
-class MatConst(Const):
- """ Matrix constant. TODO: desc. """
- value: Sequence[Sequence[Number]] # Row major, overloads Const.value
- shape: Shape # overloads Expr.shape
- name: str = "" # overloads Leaf.name
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- r = repr_type(self.shape)
-
- for i, row in enumerate(self.value):
- for j, val in enumerate(row):
- r.set(MatrixIndex(row=i, col=j), PolyIndex.constant(), val)
-
- return r, state
-
-
- def __str__(self) -> str:
- if not self.name:
- return repr(self.value)
-
- return self.name
-
-
-T = TypeVar("T", bound=Var)
-
-@dataclassabc(frozen=True)
-class MatVar(Var):
- """ Matrix of polynomial variables. TODO: desc """
- name: str # overloads Leaf.name
- shape: Shape # overloads Expr.shape
-
- # TODO: review this API, can be moved elsewhere?
- def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]:
- for row in range(self.shape.rows):
- for col in range(self.shape.cols):
- var = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg]
- entry = MatrixIndex(row, col)
-
- yield entry, var
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- r = repr_type(self.shape)
-
- # FIXME: do not hardcode scalar type
- for entry, var in self.to_scalars(PolyVar):
- idx = PolyVarIndex.from_var(var, state), # important comma!
- r.set(entry, PolyIndex(idx), 1)
-
- return r, state
-
- def __str__(self) -> str:
- return self.name
-
-
-@dataclassabc(frozen=True)
-class MatParam(Param):
- """ Matrix parameter. TODO: desc. """
- name: str
- shape: Shape
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- p = state.parameter(self)
- return MatConst(p).to_repr(repr_type, state) # type: ignore[call-arg]
-
- def __str__(self) -> str:
- return self.name
-
-
-# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓
-# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┗━┓
-# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸┗━┛
-
-
-@dataclassabc(frozen=True)
-class PolyVar(Var):
- """ Variable TODO: desc """
- name: str # overloads Leaf.name
- shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- r = repr_type(self.shape)
- idx = PolyVarIndex.from_var(self, state), # important comma!
- r.set(MatrixIndex.scalar(), PolyIndex(idx), 1)
- return r, state
-
-
-@dataclassabc(frozen=True)
-class PolyConst(Const):
- """ Constant TODO: desc """
- value: Number # overloads Const.value
- name: str = "" # overloads Leaf.name
- shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- r = repr_type(self.shape)
- r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value)
- return r, state
-
-
-@dataclassabc(frozen=True)
-class PolyParam(Param):
- """ Polynomial parameter TODO: desc """
- name: str # overloads Leaf.name
- shape: Shape = Shape.scalar() # overloads PolyExpr.shape
-
- def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
- p = state.parameter(self)
- return PolyConst(p).to_repr(repr_type, state) # type: ignore[call-arg]
-
-
# ┏━┓┏━┓┏━╸┏━┓┏━┓╺┳╸╻┏━┓┏┓╻┏━┓
# ┃ ┃┣━┛┣╸ ┣┳┛┣━┫ ┃ ┃┃ ┃┃┗┫┗━┓
# ┗━┛╹ ┗━╸╹┗╸╹ ╹ ╹ ╹┗━┛╹ ╹┗━┛
diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py
new file mode 100644
index 0000000..bdc9c67
--- /dev/null
+++ b/mdpoly/leaves.py
@@ -0,0 +1,128 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from dataclassabc import dataclassabc
+from typing import Type, TypeVar, Iterable, Sequence
+
+from .abc import Var, Const, Param
+from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number
+
+if TYPE_CHECKING:
+ from .abc import ReprT
+ from .state import State
+
+
+# ┏┳┓┏━┓╺┳╸┏━┓╻┏━╸┏━╸┏━┓
+# ┃┃┃┣━┫ ┃ ┣┳┛┃┃ ┣╸ ┗━┓
+# ╹ ╹╹ ╹ ╹ ╹┗╸╹┗━╸┗━╸┗━┛
+
+
+@dataclassabc(frozen=True)
+class MatConst(Const):
+ """ Matrix constant. TODO: desc. """
+ value: Sequence[Sequence[Number]] # Row major, overloads Const.value
+ shape: Shape # overloads Expr.shape
+ name: str = "" # overloads Leaf.name
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ r = repr_type(self.shape)
+
+ for i, row in enumerate(self.value):
+ for j, val in enumerate(row):
+ r.set(MatrixIndex(row=i, col=j), PolyIndex.constant(), val)
+
+ return r, state
+
+
+ def __str__(self) -> str:
+ if not self.name:
+ return repr(self.value)
+
+ return self.name
+
+
+T = TypeVar("T", bound=Var)
+
+@dataclassabc(frozen=True)
+class MatVar(Var):
+ """ Matrix of polynomial variables. TODO: desc """
+ name: str # overloads Leaf.name
+ shape: Shape # overloads Expr.shape
+
+ # TODO: review this API, can be moved elsewhere?
+ def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]:
+ for row in range(self.shape.rows):
+ for col in range(self.shape.cols):
+ var = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg]
+ entry = MatrixIndex(row, col)
+
+ yield entry, var
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ r = repr_type(self.shape)
+
+ # FIXME: do not hardcode scalar type
+ for entry, var in self.to_scalars(PolyVar):
+ idx = PolyVarIndex.from_var(var, state), # important comma!
+ r.set(entry, PolyIndex(idx), 1)
+
+ return r, state
+
+ def __str__(self) -> str:
+ return self.name
+
+
+@dataclassabc(frozen=True)
+class MatParam(Param):
+ """ Matrix parameter. TODO: desc. """
+ name: str
+ shape: Shape
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ p = state.parameter(self)
+ return MatConst(p).to_repr(repr_type, state) # type: ignore[call-arg]
+
+ def __str__(self) -> str:
+ return self.name
+
+
+# ┏━┓┏━┓╻ ╻ ╻┏┓╻┏━┓┏┳┓╻┏━┓╻ ┏━┓
+# ┣━┛┃ ┃┃ ┗┳┛┃┗┫┃ ┃┃┃┃┃┣━┫┃ ┗━┓
+# ╹ ┗━┛┗━╸ ╹ ╹ ╹┗━┛╹ ╹╹╹ ╹┗━╸┗━┛
+
+
+@dataclassabc(frozen=True)
+class PolyVar(Var):
+ """ Variable TODO: desc """
+ name: str # overloads Leaf.name
+ shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ r = repr_type(self.shape)
+ idx = PolyVarIndex.from_var(self, state), # important comma!
+ r.set(MatrixIndex.scalar(), PolyIndex(idx), 1)
+ return r, state
+
+
+@dataclassabc(frozen=True)
+class PolyConst(Const):
+ """ Constant TODO: desc """
+ value: Number # overloads Const.value
+ name: str = "" # overloads Leaf.name
+ shape: Shape = Shape.scalar() # ovearloads PolyExpr.shape
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ r = repr_type(self.shape)
+ r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value)
+ return r, state
+
+
+@dataclassabc(frozen=True)
+class PolyParam(Param):
+ """ Polynomial parameter TODO: desc """
+ name: str # overloads Leaf.name
+ shape: Shape = Shape.scalar() # overloads PolyExpr.shape
+
+ def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
+ p = state.parameter(self)
+ return PolyConst(p).to_repr(repr_type, state) # type: ignore[call-arg]
diff --git a/mdpoly/operations/add.py b/mdpoly/operations/add.py
index 30ad3b4..2871a47 100644
--- a/mdpoly/operations/add.py
+++ b/mdpoly/operations/add.py
@@ -5,9 +5,11 @@ from typing import Type
from dataclassabc import dataclassabc
from ..abc import Expr
+from ..leaves import PolyConst
from ..errors import AlgebraicError
from . import BinaryOp, Reducible
+from .mul import MatScalarMul
if TYPE_CHECKING:
from ..abc import ReprT
@@ -63,7 +65,7 @@ class MatSub(BinaryOp, Reducible):
def reduce(self) -> Expr:
""" See :py:meth:`mdpoly.expressions.Reducible.reduce`. """
- return self.left + (-1 * self.right)
+ return MatAdd(self.left, MatScalarMul(PolyConst(-1), self.right))
def __str__(self) -> str:
return f"({self.left} - {self.right})"