aboutsummaryrefslogtreecommitdiffstats
path: root/mdpoly/expressions/matrix.py
blob: 65f86c2935bc5f6456cdd1e42e89a9ba50c9b7fc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from __future__ import annotations
from typing import TYPE_CHECKING

from typing import Type, TypeVar, Iterable, Sequence
from dataclassabc import dataclassabc

from .poly import PolyVar, PolyConst
from ..abc import Expr, Var, Const, Param, Algebra
from ..index import MatrixIndex, PolyIndex, PolyVarIndex
# Number is needed at runtime (not only for annotations): the MatrixExpr
# operators pass it as a value to self._wrap(Number, MatConst, other).
from ..index import Number
from ..errors import MissingParameters

from .. import operations

if TYPE_CHECKING:
    from ..abc import ReprT
    from ..index import Shape, Number
    from ..state import State


class MatrixExpr(Expr):
    r""" Expression with the algebraic properties of a matrix ring and / or
    module (depending on the shape).

    We denote with :math:`R` a polynomial ring.

    - If the shape is square, like :math:`(n, n)` then this is :math:`M_n(R)`
      the ring of matrices over :math:`R`.

    - If the shape is something else like a row or column (:math:`(m, p)`,
      :math:`(1, n)` or :math:`(n, 1)`) this is a module, i.e. an algebra with
      addition and scalar multiplication, where the "scalars" come from
      :math:`R`.

    Furthermore some operators that are usually expected from matrices are
    already included (eg. transposition).

    The arithmetic operators do not compute anything; each one returns a node
    from :py:mod:`operations` that has the operands as children.
    """

    @property
    def algebra(self) -> Algebra:
        """ Algebra this expression belongs to (always ``Algebra.matrix_ring``). """
        return Algebra.matrix_ring

    def _coerce_operand(self, other):
        """ Check that ``other`` is compatible and pass it through ``_wrap``.

        Runs ``self._assert_same_algebra(self, other)`` first and then returns
        ``self._wrap(Number, MatConst, other)``, which is the preamble shared
        by all binary operators of this class.
        """
        self._assert_same_algebra(self, other)
        return self._wrap(Number, MatConst, other)

    def __add__(self, other):
        """ ``self + other`` as a ``MatAdd`` node. """
        other = self._coerce_operand(other)
        return operations.add.MatAdd(self, other)

    def __radd__(self, other):
        """ ``other + self`` as a ``MatAdd`` node (reflected operand order). """
        other = self._coerce_operand(other)
        return operations.add.MatAdd(other, self)

    def __sub__(self, other):
        """ ``self - other`` as a ``MatSub`` node. """
        other = self._coerce_operand(other)
        return operations.add.MatSub(self, other)

    def __rsub__(self, other):
        """ ``other - self`` as a ``MatSub`` node (reflected operand order). """
        other = self._coerce_operand(other)
        return operations.add.MatSub(other, self)

    def __neg__(self):
        """ ``-self``, expressed as scalar multiplication by ``PolyConst(-1)``. """
        # FIXME: Create PolyNeg?
        return operations.mul.MatScalarMul(PolyConst(-1), self)

    def __mul__(self, other):
        """ ``self * other`` as a ``MatScalarMul`` node with ``other`` as scalar.

        Currently every ``*`` is treated as scalar multiplication, with
        ``other`` placed first exactly as in :py:meth:`__rmul__`.
        """
        other = self._coerce_operand(other)

        # TODO: case distinction based on shapes
        return operations.mul.MatScalarMul(other, self)

    def __rmul__(self, other):
        """ ``other * self`` as a ``MatScalarMul`` node with ``other`` as scalar. """
        other = self._coerce_operand(other)
        return operations.mul.MatScalarMul(other, self)

    def __matmul__(self, other):
        """ ``self @ other`` as a ``MatMul`` node. """
        other = self._coerce_operand(other)
        # Use the same operations.<submodule>.<Node> path as every other
        # operator in this class.
        return operations.mul.MatMul(self, other)

    def __rmatmul__(self, other):
        """ ``other @ self`` as a ``MatMul`` node (reflected operand order). """
        # The trailing double underscore matters: without it the name is
        # mangled as a private attribute and Python never uses it for ``@``.
        other = self._coerce_operand(other)
        return operations.mul.MatMul(other, self)

    def __truediv__(self, scalar):
        """ Division by a scalar. Not implemented yet.

        :raises NotImplementedError: always.
        """
        # NOTE(review): the other operators use self._wrap(...); confirm that
        # _wrap_if_constant still exists on the base class.
        scalar = self._wrap_if_constant(scalar)
        raise NotImplementedError

    def transpose(self) -> MatrixExpr:
        """ Matrix transposition. Not implemented yet.

        :raises NotImplementedError: always.
        """
        # TODO: return operations.transpose.MatTranspose(self) once available.
        raise NotImplementedError

    @property
    def T(self) -> MatrixExpr:
        """ Shorthand for :py:meth:`mdpoly.expressions.matrix.MatrixExpr.transpose`. """
        return self.transpose()

    def to_scalar(self, scalar_type: type):
        """ Convert to a scalar expression. Not implemented yet.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError


@dataclassabc(frozen=True)
class MatConst(Const, MatrixExpr):
    """ Constant matrix given as nested sequences of numbers, one inner
    sequence per row. """
    value: Sequence[Sequence[Number]] # Row major, overloads Const.value
    shape: Shape # overloads Expr.shape
    name: str = "" # overloads Leaf.name

    def __str__(self) -> str:
        """ The name if one was given, otherwise ``repr`` of the values. """
        return self.name or repr(self.value)

    def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
        """ Build a ``repr_type`` of this shape holding the constant entries.

        Every number in ``value`` is stored at its (row, column) position
        under the index returned by ``PolyIndex.constant()``. ``state`` is
        returned as received.
        """
        matrix = repr_type(self.shape)

        for row_nr, entries in enumerate(self.value):
            for col_nr, number in enumerate(entries):
                position = MatrixIndex(row=row_nr, col=col_nr)
                matrix.set(position, PolyIndex.constant(), number)

        return matrix, state



# Type of the scalar variables produced by MatVar.to_scalars.
T = TypeVar("T", bound=Var)

@dataclassabc(frozen=True)
class MatVar(Var, MatrixExpr):
    """ Matrix whose entries are polynomial variables, one scalar variable
    per (row, column) position. """
    name: str # overloads Leaf.name
    shape: Shape # overloads Expr.shape

    def __str__(self) -> str:
        """ The name of the matrix variable. """
        return self.name

    # TODO: review this API, can be moved elsewhere?
    def to_scalars(self, scalar_var_type: Type[T]) -> Iterable[tuple[MatrixIndex, T]]:
        """ Lazily yield one scalar variable for each entry of the matrix.

        Entries are visited row by row. Each yielded pair holds the
        ``MatrixIndex`` of the entry and a new ``scalar_var_type`` whose name
        is this matrix's name followed by ``_[row,col]``.
        """
        for row in range(self.shape.rows):
            for col in range(self.shape.cols):
                scalar = scalar_var_type(name=f"{self.name}_[{row},{col}]") # type: ignore[call-arg]
                yield MatrixIndex(row, col), scalar

    def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
        """ Build a ``repr_type`` of this shape with coefficient 1 per entry.

        Each entry gets coefficient 1 at the ``PolyIndex`` made of the single
        ``PolyVarIndex`` of that entry's scalar variable.
        """
        matrix = repr_type(self.shape)

        # FIXME: do not hardcode scalar type
        for position, scalar in self.to_scalars(PolyVar):
            # PolyIndex receives a one-element tuple, hence the explicit
            # tuple literal.
            factors = (PolyVarIndex.from_var(scalar, state),)
            matrix.set(position, PolyIndex(factors), 1)

        return matrix, state


@dataclassabc(frozen=True)
class MatParam(Param, MatrixExpr):
    """ Matrix parameter: a named placeholder whose value is looked up in
    ``state.parameters`` when a representation is built. """
    name: str
    shape: Shape

    def to_repr(self, repr_type: Type[ReprT], state: State) -> tuple[ReprT, State]:
        """ Build the representation of the value assigned to this parameter.

        The value found in ``state.parameters`` is turned into a ``MatConst``
        with this parameter's shape, and that constant's representation is
        returned.

        :raises MissingParameters: if ``state.parameters`` has no entry for
            this parameter.
        """
        if self not in state.parameters:
            raise MissingParameters("Cannot construct representation because "
                                    f"value for parameter {self} was not given.")

        # FIXME: add conversion to scalar variables
        # MatConst requires both value and shape (shape has no default), so
        # the shape of this parameter is forwarded explicitly.
        # Ignore typecheck because dataclassabc has not type stub
        constant = MatConst(value=state.parameters[self], shape=self.shape) # type: ignore[call-arg]
        return constant.to_repr(repr_type, state)

    def __str__(self) -> str:
        """ The name of the parameter. """
        return self.name