summaryrefslogtreecommitdiffstats
path: root/polymatrix/mixins/oldpolymatrixexprstatemixin.py
blob: 0b95092de091c2bfeb0b7bf3c1d42454af1ad3d8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import abc
import itertools
import dataclasses

from polymatrix.mixins.oldpolymatrixmixin import OldPolyMatrixMixin
from polymatrix.utils import monomial_to_index


class OldPolyMatrixExprStateMixin(abc.ABC):
    """
    State that assigns parameter index ranges to polynomial matrices.

    Every registered ``(polymat, degree)`` pair owns a half-open interval
    ``[start, end)`` of parameter indices, stored in ``offset_dict``.
    The update methods never modify ``self``; they build a new state with
    ``dataclasses.replace``, so concrete implementations must be dataclasses
    that expose ``offset_dict`` and ``n_param`` as fields.
    """

    @property
    @abc.abstractmethod
    def n_var(self) -> int:
        """
        dimension of x
        """

        ...

    @property
    @abc.abstractmethod
    def n_param(self) -> int:
        """
        current number of parameters used in polynomial matrix expressions
        """

        ...

    @property
    @abc.abstractmethod
    def offset_dict(self) -> 'dict[tuple[OldPolyMatrixMixin, int], tuple[int, int]]':
        """
        mapping from ``(polymat, degree)`` to the half-open parameter index
        range ``(start, end)`` reserved for that pair
        """

        ...

    def get_polymat(self, p_index: int) -> 'tuple[OldPolyMatrixMixin, int, int]':
        """
        Find the polynomial matrix that owns the parameter index ``p_index``.

        Returns the tuple ``(polymat, degree, local_index)`` where
        ``local_index`` is ``p_index`` relative to the start of the range
        registered for ``(polymat, degree)``.

        Raises ``IndexError`` if no registered range contains ``p_index``.
        """

        for (polymat, degree), (start, end) in self.offset_dict.items():
            if start <= p_index < end:
                return polymat, degree, p_index - start

        # IndexError is a subclass of Exception, so callers that catch the
        # previously raised generic Exception keep working.
        raise IndexError(f'index {p_index} not found in offset dictionary')

    def update_offsets(self, polymats: 'tuple[OldPolyMatrixMixin, ...]') -> 'OldPolyMatrixExprStateMixin':
        """
        Reserve parameter index ranges for polynomial matrices not seen before.

        Constant polynomial matrices and those already present in
        ``offset_dict`` are skipped. For every remaining matrix and each of
        its degrees, a contiguous range is appended after the current
        ``n_param``. Ranges are handed out in the order the matrices appear
        in ``polymats``; duplicates are only registered once.

        Returns ``self`` unchanged if nothing has to be registered, otherwise
        a copy with the extended ``offset_dict`` and the increased ``n_param``.
        """

        registered = {polymat for polymat, _ in self.offset_dict}

        # dict.fromkeys removes duplicates but, unlike a set, keeps the input
        # order so that the assigned index ranges are reproducible.
        new_polymats = tuple(dict.fromkeys(
            p for p in polymats if not p.is_constant and p not in registered
        ))

        keys = []
        sizes = []

        for polymat in new_polymats:
            n_entries = polymat.shape[0] * polymat.shape[1]

            for degree in polymat.degrees:
                # number of terms is given by the maximum index + 1
                # NOTE(review): assumes monomial_to_index returns the largest
                # index for this degree when every factor is the last
                # variable (n_var - 1) -- confirm against polymatrix.utils.
                max_index = monomial_to_index(self.n_var, degree * (self.n_var - 1,))

                keys.append((polymat, degree))
                sizes.append(int(n_entries * (max_index + 1)))

        # Covers an empty input, only constant or already registered
        # matrices, and new matrices that do not define any degree.
        if not keys:
            return self

        # Running totals starting at n_param; consecutive pairs are the
        # (start, end) ranges of the new entries.
        bounds = tuple(itertools.accumulate(sizes, initial=self.n_param))
        new_offsets = dict(zip(keys, itertools.pairwise(bounds)))

        return dataclasses.replace(
            self,
            offset_dict=self.offset_dict | new_offsets,
            n_param=bounds[-1],
        )

    def update_param(self, n: int) -> 'OldPolyMatrixExprStateMixin':
        """
        Return a copy of this state whose ``n_param`` is increased by ``n``.
        """

        return dataclasses.replace(self, n_param=self.n_param + n)