summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/hessianexprmixin.py
blob: 0a89d5b67d8b488fe94c37e060a786b0646c4aef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import annotations

from abc import abstractmethod
from typing_extensions import override

from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin

from polymatrix.expressionstate import ExpressionState
from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict
from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.polymatrix.mixins import PolyMatrixMixin

class HessianExprMixin(ExpressionBaseMixin):
    """
    Compute the hessian matrix (second derivative) of a scalar polynomial.

    The result is an ``nvars x nvars`` polynomial matrix, where ``nvars`` is
    the number of rows of :attr:`variables`. Entry ``(i, j)`` holds the
    derivative of :attr:`underlying` taken first with respect to variable
    ``j`` and then with respect to variable ``i``. Entries whose derivative
    is falsy (e.g. empty) are not stored.
    """

    @property
    @abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """ Expression to take the derivative of """

    @property
    @abstractmethod
    def variables(self) -> ExpressionBaseMixin:
        """
        Variables with respect to which the derivative is taken.
        This must be a column vector.
        """

    @override
    def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
        """
        Evaluate both operands and build the hessian of the underlying scalar.

        Returns the state obtained after evaluating ``underlying`` and then
        ``variables``, together with the ``(nvars, nvars)`` hessian matrix.

        Raises:
            ValueError: if ``underlying`` is not of shape ``(1, 1)``, if
                ``variables`` is not a column vector, or if an entry of
                ``variables`` is not exactly one monomial of degree one.
        """
        state, underlying = self.underlying.apply(state)
        state, variables = self.variables.apply(state)

        # Check arguments
        if underlying.shape != (1, 1):
            raise ValueError("Cannot take Hessian of non-scalar expression "
                             f"with shape {underlying.shape}")

        if variables.shape[1] != 1:
            raise ValueError("Cannot take derivative with respect to matrix "
                             f"with shape {variables.shape}. It must be a column vector")

        # Collect the index of the variable found in each row of the column
        # vector, validating that the row is a single degree-one monomial.
        nvars, _ = variables.shape
        variable_indices = []
        for i in range(nvars):
            monomials = tuple(variables.at(i, 0).monomials())

            if not monomials:
                # Without this guard, monomials[0] below would fail with an
                # unhelpful IndexError instead of a ValueError.
                raise ValueError("Cannot take derivative with respect to entry "
                                 f"{i} of variables: it has no monomials")

            if len(monomials) > 1:
                # FIXME improve error message
                raise ValueError("Cannot take derivative with respect to polynomial")

            monomial = monomials[0]
            if monomial.degree != 1:
                # FIXME improve error message
                raise ValueError("Cannot take derivative with respect to non-linear term")

            variable_indices.append(monomial[0].index)

        # The first partial derivatives depend only on the column index j, so
        # compute them once instead of once per (i, j) pair.
        # NOTE(review): relies on PolyDict.differentiate not mutating its
        # argument; the original nested calls relied on the same -- confirm.
        scalar = underlying.scalar()
        first_derivatives = [
            PolyDict.differentiate(scalar, index)
            for index in variable_indices
        ]

        # Actually compute derivative
        # FIXME: make only upper triangular part and create symmetrix polymatrix
        # FIXME: before that, take out SymmetrixPolyMatrix from SymmetrixExprMixin
        result = PolyMatrixDict.empty()
        for i, row_index in enumerate(variable_indices):
            for j, first_derivative in enumerate(first_derivatives):
                new_poly = PolyDict.differentiate(first_derivative, row_index)

                # Only store non-empty entries
                if new_poly:
                    result[i, j] = new_poly

        return state, init_poly_matrix(result, shape=(nvars, nvars))