aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mdpoly/__init__.py26
-rw-r--r--mdpoly/abc.py81
-rw-r--r--mdpoly/algebra.py224
-rw-r--r--mdpoly/index.py4
-rw-r--r--mdpoly/leaves.py154
-rw-r--r--mdpoly/state.py6
6 files changed, 297 insertions, 198 deletions
diff --git a/mdpoly/__init__.py b/mdpoly/__init__.py
index 2c007d0..c976db3 100644
--- a/mdpoly/__init__.py
+++ b/mdpoly/__init__.py
@@ -3,14 +3,14 @@
# internal classes imported with underscore because
# they should not be exposed to the end users
-from .abc import (Algebra as _Algebra, Shape as _Shape)
+from .abc import (Shape as _Shape)
-from .algebra import (PolyRingExpr as _PolyRingExpr,
- MatrixExpr as _MatrixExpr)
-
-from .leaves import (Const as _Const, MatConst as _MatConst,
- Var as _Var, MatVar as _MatVar,
- Param as _Param, MatParam as _MatParam)
+from .algebra import (PolyConst as _PolyConst,
+ PolyVar as _PolyVar,
+ PolyParam as _PolyParam,
+ MatConst as _MatConst,
+ MatVar as _MatVar,
+ MatParam as _MatParam)
from .state import State as _State
@@ -40,27 +40,27 @@ class FromHelpers:
yield from map(cls, names)
-class Constant(_Const, _PolyRingExpr, FromHelpers):
+class Constant(_PolyConst, FromHelpers):
""" Constant values """
-class Variable(_Var, _PolyRingExpr, FromHelpers):
+class Variable(_PolyVar, FromHelpers):
""" Polynomial Variable """
-class Parameter(_Param, _PolyRingExpr, FromHelpers):
+class Parameter(_PolyParam, FromHelpers):
""" Parameter that can be substituted """
-class MatrixConstant(_MatConst, _PolyRingExpr):
+class MatrixConstant(_MatConst):
""" Matrix constant """
-class MatrixVariable(_MatVar, _MatrixExpr):
+class MatrixVariable(_MatVar):
""" Matrix Polynomial Variable """
-class MatrixParameter(_MatParam, _MatrixExpr):
+class MatrixParameter(_MatParam):
""" Matrix Parameter """
diff --git a/mdpoly/abc.py b/mdpoly/abc.py
index d7578ec..1914168 100644
--- a/mdpoly/abc.py
+++ b/mdpoly/abc.py
@@ -1,16 +1,25 @@
""" Abstract Base Classes of MDPoly """
from __future__ import annotations
+from typing import TYPE_CHECKING
from .index import Number, Shape, MatrixIndex, PolyIndex
from .constants import NUMERICS_EPS
from .errors import AlgebraicError
-from .state import State
from .util import iszero
-from typing import Self, Any, Iterable, Sequence
+if TYPE_CHECKING:
+ from .state import State
+
+from typing import Self, TypeVar, Generic, Any, Iterable, Sequence
from enum import Enum, auto
from copy import copy
from abc import ABC, abstractmethod
+from dataclassabc import dataclassabc
+
+
+# ┏━╸╻ ╻┏━┓┏━┓┏━╸┏━┓┏━┓╻┏━┓┏┓╻┏━┓
+# ┣╸ ┏╋┛┣━┛┣┳┛┣╸ ┗━┓┗━┓┃┃ ┃┃┗┫┗━┓
+# ┗━╸╹ ╹╹ ╹┗╸┗━╸┗━┛┗━┛╹┗━┛╹ ╹┗━┛
class Algebra(Enum):
@@ -232,10 +241,24 @@ class Expr(ABC):
class Leaf(Expr):
""" Leaf of the binary tree. """
- name: str
+
+ # --- Properties ---
+
+ @property
+ @abstractmethod
+ def name(self) -> str:
+ """ Name of the leaf. """
+
+ # --- Magic methods ---
+
+ def __repr__(self) -> str:
+ return self.name
+
+ # -- Overloads ---
@property
def is_leaf(self):
+ """ Overloads :py:meth:`mdpoly.abc.Expr.is_leaf`. """
return True
@property
@@ -259,12 +282,64 @@ class Leaf(Expr):
(Leaves do not have children). """
+T = TypeVar("T")
+
+
+class Const(Leaf, Generic[T]):
+ """ Leaf to represent a constant. TODO: desc. """
+
+ @property
+ @abstractmethod
+ def value(self) -> T:
+ """ Value of the constant. """
+
+ def __repr__(self) -> str:
+ if not self.name:
+ return repr(self.value)
+
+ return self.name
+
+
+class Var(Leaf):
+ """ Leaf to reprsent a Variable. TODO: desc """
+
+
+class Param(Leaf):
+ """ Parameter. TODO: desc """
+
+
+@dataclassabc(frozen=True)
+class Nothing(Leaf):
+ """
+ A leaf that is nothing. This is a placeholder (eg. for unary operators).
+ """
+ name: str = ""
+ shape: Shape = Shape(0, 0)
+ algebra: Algebra = Algebra.none
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ raise ValueError("Nothing cannot be represented.")
+
+ def __repr__(self) -> str:
+ return "Nothing"
+
+
+# ┏━┓┏━╸╻ ┏━┓╺┳╸╻┏━┓┏┓╻┏━┓
+# ┣┳┛┣╸ ┃ ┣━┫ ┃ ┃┃ ┃┃┗┫┗━┓
+# ╹┗╸┗━╸┗━╸╹ ╹ ╹ ╹┗━┛╹ ╹┗━┛
+
+
class Rel(ABC):
""" Relation between two expressions. """
lhs: Expr
rhs: Expr
+# ┏━┓┏━╸┏━┓┏━┓┏━╸┏━┓┏━╸┏┓╻╺┳╸┏━┓╺┳╸╻┏━┓┏┓╻┏━┓
+# ┣┳┛┣╸ ┣━┛┣┳┛┣╸ ┗━┓┣╸ ┃┗┫ ┃ ┣━┫ ┃ ┃┃ ┃┃┗┫┗━┓
+# ╹┗╸┗━╸╹ ╹┗╸┗━╸┗━┛┗━╸╹ ╹ ╹ ╹ ╹ ╹ ╹┗━┛╹ ╹┗━┛
+
+
class Repr(ABC):
r""" Representation of a multivariate matrix polynomial expression.
diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py
index b98af01..a715b7e 100644
--- a/mdpoly/algebra.py
+++ b/mdpoly/algebra.py
@@ -1,17 +1,17 @@
""" Algebraic Structures for Expressions """
from __future__ import annotations
-from .abc import Algebra, Expr, Repr
-from .leaves import Nothing, Const, Param, Var, MatConst
-from .errors import AlgebraicError, InvalidShape
+from .abc import Algebra, Expr, Repr, Nothing, Const, Var, Param
+from .errors import AlgebraicError, InvalidShape, MissingParameters
from .state import State
-from .index import Shape, MatrixIndex, PolyIndex, Number
+from .index import Shape, MatrixIndex, PolyIndex, PolyVarIndex, Number
-from typing import cast, Iterable
+from typing import cast, Sequence, Iterable, Type, TypeVar
from functools import reduce
from itertools import product, chain, combinations_with_replacement
from math import prod
from abc import abstractmethod
+from dataclassabc import dataclassabc
import operator
@@ -199,9 +199,41 @@ class PolyRingExpr(Expr):
return PolyExp(self, other)
-class PolyConst(Const, PolyRingExpr): ...
-class PolyVar(Const, PolyRingExpr): ...
-class PolyParam(Param, PolyRingExpr): ...
+@dataclassabc(frozen=True, repr=False)
+class PolyVar(Var, PolyRingExpr):
+ """ Variable TODO: desc """
+ name: str
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ r = repr_type()
+ idx = PolyVarIndex.from_var(self, state),
+ r.set(MatrixIndex.scalar(), PolyIndex(idx), 1)
+ return r, state
+
+
+@dataclassabc(frozen=True, repr=False)
+class PolyConst(Const, PolyRingExpr):
+ """ Constant TODO: desc """
+ value: Number
+ name: str = ""
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ r = repr_type()
+ r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value)
+ return r, state
+
+
+@dataclassabc(frozen=True, repr=False)
+class PolyParam(Param, PolyRingExpr):
+ """ Polynomial parameter TODO: desc """
+ name: str
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ if self not in state.parameters:
+ raise MissingParameters("Cannot construct representation because "
+ f"value for parameter {self} was not given.")
+
+ return PolyConst(state.parameters[self]).to_repr(repr_type, state)
class PolyAdd(BinaryOp, PolyRingExpr):
@@ -374,32 +406,32 @@ class MatrixExpr(Expr):
def __sub__(self, other):
self._assert_same_algebra(self, other)
- other = self._wrap(Number, Const, other)
+ other = self._wrap(Number, MatConst, other)
return MatSub(self, other)
def __rsub__(self, other):
self._assert_same_algebra(self, other)
- other = self._wrap(Number, Const, other)
+ other = self._wrap(Number, MatConst, other)
return MatSub(other, self)
def __mul__(self, other):
self._assert_same_algebra(self, other)
- other = self._wrap(Number, Const, other)
+ other = self._wrap(Number, MatConst, other)
return MatScalarMul(other, self)
def __rmul__(self, other):
self._assert_same_algebra(self, other)
- other = self._wrap(Number, Const, other)
+ other = self._wrap(Number, MatConst, other)
return MatScalarMul(other, self)
def __matmul__(self, other):
self._assert_same_algebra(self, other)
- other = self._wrap(Number, Const, other)
+ other = self._wrap(Number, MatConst, other)
return MatMul(self, other)
def __rmatmul(self, other):
self._assert_same_algebra(self, other)
- other = self._wrap(Number, Const, other)
+ other = self._wrap(Number, MatConst, other)
return MatMul(other, self)
def __truediv__(self, scalar):
@@ -412,7 +444,7 @@ class MatrixExpr(Expr):
@property
def T(self) -> MatrixExpr:
- """ Shorthand for :py:meth:`mdpoly.algebra.MatrixAlgebra.transpose`. """
+ """ Shorthand for :py:meth:`mdpoly.algebra.MatrixExpr.transpose`. """
return self.transpose()
def to_scalar(self, scalar_type: type):
@@ -420,9 +452,99 @@ class MatrixExpr(Expr):
raise NotImplementedError
-class MatAdd(BinaryOp, PolyRingExpr):
+@dataclassabc(frozen=True)
+class MatConst(Const, MatrixExpr):
+ """
+ A matrix constant
+ """
+ value: Sequence[Sequence[Number]] # Row major
+ shape: Shape
+ name: str = ""
+ algebra: Algebra = Algebra.matrix_ring
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ r = repr_type()
+
+ for r, row in enumerate(self.value):
+ for c, val in enumerate(row):
+ r.set(MatrixIndex(row=r, col=c), PolyIndex.constant(), val)
+
+ return r, state
+
+
+ def __repr__(self) -> str:
+ if not self.name:
+ return repr(self.value)
+
+ return self.name
+
+
+
+T = TypeVar("T", bound=Var)
+
+@dataclassabc(frozen=True)
+class MatVar(Var, MatrixExpr):
+ """
+ Matrix polynomial variable
+ """
+ name: str
+ shape: Shape
+ algebra: Algebra = Algebra.matrix_ring
+
+ def to_scalars(self, scalar_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]:
+ for row in range(self.shape.rows):
+ for col in range(self.shape.cols):
+ var = scalar_type(name=f"{self.name}_[{row},{col}]")
+ idx = MatrixIndex(row, col)
+
+ yield idx, var
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ r = repr_type()
+
+ # FIXME: do not hardcode scalar type
+ for idx, var in self.to_scalars(Var):
+ ...
+
+ return r, state
+
+ def __repr__(self) -> str:
+ return self.name
+
+
+@dataclassabc(frozen=True)
+class MatParam(Param, MatrixExpr):
+ """
+ Matrix parameter
+ """
+ name: str
+ shape: Shape
+ algebra: Algebra = Algebra.poly_ring
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ if self not in state.parameters:
+ raise MissingParameters("Cannot construct representation because "
+ f"value for parameter {self} was not given.")
+
+ # FIXME: add conversion to scalar variables
+ return MatConst(state.parameters[self]).to_repr(repr_type, state)
+
+ def __repr__(self) -> str:
+ return self.name
+
+
+class MatAdd(BinaryOp, MatrixExpr):
""" Addition operator between matrices. """
+ @property
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if not self.left.shape == self.right.shape:
+ raise AlgebraicError(f"Cannot add matrices {self.left} and {self.right} "
+ "with different shapes.")
+ return self.left.shape
+
+
def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
""" See :py:meth:`mdpoly.abc.Expr.to_repr`. """
# Make a new empty representation
@@ -444,9 +566,17 @@ class MatAdd(BinaryOp, PolyRingExpr):
return f"({self.left} + {self.right})"
-class MatSub(BinaryOp, PolyRingExpr, Reducible):
+class MatSub(BinaryOp, MatrixExpr, Reducible):
""" Subtraction operator between matrices. """
+ @property
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if not self.left.shape == self.right.shape:
+ raise AlgebraicError(f"Cannot subtract matrices {self.left} and {self.right} "
+ f"with different shapes, {self.left.shape} and {self.right.shape}.")
+ return self.left.shape
+
def reduce(self) -> Expr:
""" See :py:meth:`mdpoly.algebra.Reducible.reduce`. """
return self.left + (-1 * self.right)
@@ -458,23 +588,71 @@ class MatSub(BinaryOp, PolyRingExpr, Reducible):
class MatElemMul(BinaryOp, MatrixExpr):
""" Elementwise Matrix Multiplication. """
+ @property
+ def shape(self) -> Shape:
+ """ See :py:meth:`mdpoly.abc.Expr.shape`. """
+ if not self.left.shape == self.right.shape:
+ raise AlgebraicError("Cannot perform element-wise multiplication of matrices "
+ f"{self.left} and {self.right} with different shapes, "
+ f"{self.left.shape} and {self.right.shape}")
+ return self.left.shape
+
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """
+ r = repr_type()
+
+ lrepr, state = self.left.to_repr(repr_type, state)
+ rrepr, state = self.right.to_repr(repr_type, state)
+
+ # Non zero entries are the intersection since if either is zero the
+ # result is zero
+ nonzero_entries = set(lrepr.entries()) & set(rrepr.entries())
+ for entry in nonzero_entries:
+ # Compute polynomial product between non-zero entries
+ for lterm, rterm in product(lrepr.terms(entry), rrepr.terms(entry)):
+ # Compute where the results should go
+ term = PolyIndex.product(lterm, rterm)
+
+ # Compute product
+ p = r.at(entry, term) + lrepr.at(entry, lterm) * rrepr.at(entry, rterm)
+ r.set(entry, term, p)
+
+ return r, state
+
def __repr__(self) -> str:
return f"({self.left} .* {self.right})"
class MatScalarMul(BinaryOp, MatrixExpr):
- """ Matrix-Scalar Multiplication. """
+ """ Matrix-Scalar Multiplication. Assumes scalar is on the left and matrix
+ on the right. """
@property
def shape(self) -> Shape:
""" See :py:meth:`mdpoly.abc.Expr.shape`. """
- if self.right.shape == Shape.scalar:
- return self.left.shape
+ if not self.left.shape == Shape.scalar():
+ raise InvalidShape(f"Matrix-scalar product assumes that left argumet {self.left} "
+ f"but it has shape {self.left.shape}")
+
+ return self.right.shape
- elif self.left.shape == Shape.scalar:
- return self.right.shape
- raise InvalidShape(f"Either {self.left} or {self.right} must be a scalar.")
+ def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
+ """ See :py:meth:`mdpoly.abc.Expr.to_repr`. """
+ raise NotImplementedError
+ r = repr_type()
+
+ scalar_repr, state = self.left.to_repr(repr_type, state)
+ mat_repr, state = self.right.to_repr(repr_type, state)
+
+ for entry in mat_repr.entries():
+ for term in mat_repr.terms(entry):
+ ...
+
+ return r, state
+
+ def __repr__(self) -> str:
+ return f"({self.left} * {self.right})"
class MatDotProd(BinaryOp, MatrixExpr):
diff --git a/mdpoly/index.py b/mdpoly/index.py
index 02318a4..b1aa8c1 100644
--- a/mdpoly/index.py
+++ b/mdpoly/index.py
@@ -4,11 +4,11 @@ from .errors import InvalidShape
from .util import partition, isclose
from itertools import filterfalse
-from typing import Self, NamedTuple, Iterable, Optional, TYPE_CHECKING
+from typing import Self, NamedTuple, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .state import State
- from .leaves import Var
+ from .abc import Var
Number = int | float
diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py
deleted file mode 100644
index fe2be49..0000000
--- a/mdpoly/leaves.py
+++ /dev/null
@@ -1,154 +0,0 @@
-from .abc import Algebra, Leaf, Repr
-from .index import Number, Shape, MatrixIndex, PolyVarIndex, PolyIndex
-from .state import State
-from .errors import MissingParameters
-
-from dataclasses import dataclass
-from dataclassabc import dataclassabc
-from typing import Iterable
-
-
-@dataclass(frozen=True)
-class Nothing(Leaf):
- """
- A leaf that is nothing. This is a placeholder (eg. for unary operators).
- """
- name: str = ""
- shape: Shape = Shape(0, 0)
- algebra: Algebra = Algebra.none
-
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
- raise ValueError("Nothing cannot be represented.")
-
- def __repr__(self) -> str:
- return "Nothing"
-
-
-# ┏━┓┏━╸┏━┓╻ ┏━┓┏━┓ ╻ ┏━╸┏━┓╻ ╻┏━╸┏━┓
-# ┗━┓┃ ┣━┫┃ ┣━┫┣┳┛ ┃ ┣╸ ┣━┫┃┏┛┣╸ ┗━┓
-# ┗━┛┗━╸╹ ╹┗━╸╹ ╹╹┗╸ ┗━╸┗━╸╹ ╹┗┛ ┗━╸┗━┛
-
-
-@dataclass(frozen=True)
-class Const(Leaf):
- """
- A scalar constant
- """
- value: Number
- name: str = ""
- shape: Shape = Shape.scalar()
- algebra: Algebra = Algebra.poly_ring
-
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
- r = repr_type()
- r.set(MatrixIndex.scalar(), PolyIndex.constant(), self.value)
- return r, state
-
- def __repr__(self) -> str:
- if not self.name:
- return repr(self.value)
-
- return self.name
-
-
-@dataclass(frozen=True)
-class Var(Leaf):
- """
- Polynomial variable
- """
- name: str
- shape: Shape = Shape.scalar()
- algebra: Algebra = Algebra.poly_ring
-
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
- r = repr_type()
- idx = PolyVarIndex.from_var(self, state),
- r.set(MatrixIndex.scalar(), PolyIndex(idx), 1)
- return r, state
-
- def __repr__(self) -> str:
- return self.name
-
-
-@dataclass(frozen=True)
-class Param(Leaf):
- """
- A parameter
- """
- name: str
- shape: Shape = Shape.scalar()
- algebra: Algebra = Algebra.poly_ring
-
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
- if self not in state.parameters:
- raise MissingParameters("Cannot construct representation because "
- f"value for parameter {self} was not given.")
-
- return Const(state.parameters[self]).to_repr(repr_type, state)
-
- def __repr__(self) -> str:
- return self.name
-
-
-# ┏┳┓┏━┓╺┳╸┏━┓╻╻ ╻ ╻ ┏━╸┏━┓╻ ╻┏━╸┏━┓
-# ┃┃┃┣━┫ ┃ ┣┳┛┃┏╋┛ ┃ ┣╸ ┣━┫┃┏┛┣╸ ┗━┓
-# ╹ ╹╹ ╹ ╹ ╹┗╸╹╹ ╹ ┗━╸┗━╸╹ ╹┗┛ ┗━╸┗━┛
-
-
-@dataclassabc(frozen=True)
-class MatConst(Leaf):
- """
- A matrix constant
- """
- value: Iterable[Iterable[Number]]
- shape: Shape
- name: str = ""
- algebra: Algebra = Algebra.matrix_ring
-
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
- raise UserWarning("MatConst.to_repr is not implemented!")
- return repr_type(), state
-
- def __repr__(self) -> str:
- if not self.name:
- return repr(self.value)
-
- return self.name
-
-
-@dataclassabc(frozen=True)
-class MatVar(Leaf):
- """
- Matrix polynomial variable
- """
- name: str
- shape: Shape
- algebra: Algebra = Algebra.matrix_ring
-
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
- raise UserWarning("MatVar.to_repr is not implemented!")
- return repr_type(), state
-
- def __repr__(self) -> str:
- return self.name
-
-
-@dataclassabc(frozen=True)
-class MatParam(Leaf):
- """
- Matrix parameter
- """
- name: str
- shape: Shape
- algebra: Algebra = Algebra.poly_ring
-
- def to_repr(self, repr_type: type, state: State) -> tuple[Repr, State]:
- if self not in state.parameters:
- raise MissingParameters("Cannot construct representation because "
- f"value for parameter {self} was not given.")
-
- # FIXME: add conversion to scalar variables
- return MatConst(state.parameters[self]).to_repr(repr_type, state)
-
- def __repr__(self) -> str:
- return self.name
diff --git a/mdpoly/state.py b/mdpoly/state.py
index e0a6bc5..6924146 100644
--- a/mdpoly/state.py
+++ b/mdpoly/state.py
@@ -3,10 +3,10 @@ from typing import TYPE_CHECKING
from dataclasses import dataclass, field
from .index import PolyVarIndex
+from .abc import Var, Param, Const
if TYPE_CHECKING:
from .index import Number
- from .leaves import Var, Param, Const
Index = int
@@ -24,9 +24,9 @@ class State:
def index(self, var: Var) -> Index:
""" Get the index for a variable. """
- from .leaves import Var
if not isinstance(var, Var):
- raise IndexError(f"Only variables (type {Var}) can be indexed.")
+ raise IndexError(f"Cannot index {var} (type {type(var)}). "
+ f"Only variables (type {Var}) can be indexed.")
if var not in self.variables.keys():
new_index = self._make_index()