summaryrefslogtreecommitdiffstats
path: root/sumofsquares/sosexprbase/init/initsosexprbase.py
blob: f8f97fd15a6219eb61d1b83fb03f9351f849ee57 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import polymatrix

from sumofsquares.sosexprbase.impl import ParamSOSExprBaseImpl, SOSExprBaseImpl
from sumofsquares.sosexprbase.mixins.parametermixin import ParameterMixin


def init_sos_expr_base(
    expr: polymatrix.Expression,
    variables: polymatrix.Expression,
    dependence: tuple[ParameterMixin, ...] | None = None,
):
    """Build a ``SOSExprBaseImpl`` from an expression, its variables and dependence.

    Args:
        expr: The expression to wrap. Any value that is not already a
            ``polymatrix.Expression`` is converted with ``polymatrix.from_``.
        variables: Expression holding the variables; forwarded unchanged.
        dependence: Objects (annotated as ``ParameterMixin``) this expression
            depends on. ``None`` means "no dependence" and becomes an empty
            tuple. Any other iterable is normalized to a tuple.

    Returns:
        A ``SOSExprBaseImpl`` built from ``expr``, ``variables`` and
        ``dependence``.
    """
    if not isinstance(expr, polymatrix.Expression):
        expr = polymatrix.from_(expr)

    # Normalize to a tuple so callers may also pass a list or other iterable;
    # an existing tuple passes through unchanged.
    dependence = tuple() if dependence is None else tuple(dependence)

    return SOSExprBaseImpl(
        expr=expr,
        variables=variables,
        dependence=dependence,
    )


def init_param_sos_expr_base(
    name: str,
    variables: polymatrix.Expression,
    monom: polymatrix.Expression | None = None,
    n_row: int | None = None,
    n_col: int | None = None,
):
    """Build a ``ParamSOSExprBaseImpl`` whose parameters are derived from ``monom``.

    Parameters are created with ``monom.parametrize(name)``. For a shape other
    than 1x1, ``monom`` is first repeated ``n_row * n_col`` times with
    ``rep_mat`` before parametrizing. Presumably this gives every matrix entry
    its own parameters (confirm against polymatrix).

    Args:
        name: Name passed (via an f-string) to ``parametrize`` and stored on
            the result.
        variables: Expression holding the variables; forwarded unchanged.
        monom: Monomial expression to parametrize. Defaults to
            ``polymatrix.from_(1)``. Non-``Expression`` values are wrapped
            with ``polymatrix.from_``.
        n_row: Number of rows; ``None`` means 1.
        n_col: Number of columns; ``None`` means 1.

    Returns:
        A ``ParamSOSExprBaseImpl`` holding the parameter vector ``param`` and
        its transposed/reshaped form ``param_matrix``.
    """
    if monom is None:
        monom = polymatrix.from_(1)
    elif not isinstance(monom, polymatrix.Expression):
        # Accept plain values the same way ``init_sos_expr_base`` accepts them
        # for ``expr``; otherwise ``.parametrize`` below would fail.
        monom = polymatrix.from_(monom)

    n_row = 1 if n_row is None else n_row
    n_col = 1 if n_col is None else n_col

    if n_row == 1 and n_col == 1:
        param = monom.parametrize(f'{name}')
        param_matrix = param.T
    else:
        # Stack n_row * n_col copies of the monomial vector and parametrize
        # the stacked vector in a single call.
        param = monom.rep_mat(n_col * n_row, 1).parametrize(f'{name}')
        # NOTE(review): reshape receives ``monom`` itself as the row argument;
        # presumably polymatrix takes the row count from its shape — confirm.
        param_matrix = param.reshape(monom, -1).T

    # NOTE(review): n_col is not forwarded; presumably the impl derives the
    # column count from param_matrix — confirm.
    return ParamSOSExprBaseImpl(
        name=name,
        param=param,
        monom=monom,
        variables=variables,
        param_matrix=param_matrix,
        n_row=n_row,
    )