diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/__init__.py | 96 |
1 files changed, 51 insertions, 45 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 1c072b4..8ee1488 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -4,6 +4,7 @@ import itertools import typing import numpy as np import scipy.sparse +import sympy from polymatrix.expression.expression import Expression from polymatrix.expressionstate.expressionstate import ExpressionState @@ -20,10 +21,10 @@ from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin from polymatrix.expression.utils.monomialtoindex import monomial_to_index from polymatrix.expressionstate.init.initexpressionstate import init_expression_state as original_init_expression_state + def init_expression_state(): return original_init_expression_state() - def from_sympy( data: tuple[tuple[float]], ): @@ -31,13 +32,11 @@ def from_sympy( init_from_sympy_expr(data) ) - def from_( data: tuple[tuple[float]], ): return from_sympy(data) - def from_polymatrix( polymatrix: PolyMatrix, ): @@ -45,7 +44,6 @@ def from_polymatrix( init_from_terms_expr(polymatrix) ) - def v_stack( expressions: tuple[Expression], ): @@ -53,7 +51,6 @@ def v_stack( init_v_stack_expr(expressions) ) - def h_stack( expressions: tuple[Expression], ): @@ -61,7 +58,6 @@ def h_stack( init_v_stack_expr(tuple(expr.T for expr in expressions)) ).T - def block_diag( expressions: tuple[Expression], ): @@ -69,7 +65,6 @@ def block_diag( init_block_diag_expr(expressions) ) - def eye( variable: tuple[Expression], ): @@ -77,16 +72,6 @@ def eye( init_eye_expr(variable=variable) ) - -# def sos_matrix( -# underlying: Expression, -# ): -# def func(state: ExpressionState): -# underlying = - -# return init_state_monad(func) - - def kkt_equality( variables: Expression, equality: Expression = None, @@ -269,6 +254,37 @@ def kkt_inequality( return init_state_monad(func) +def rows( + expr: Expression, +) -> StateMonadMixin[ExpressionState, np.ndarray]: + + def func(state: ExpressionState): + state, underlying = expr.apply(state) + + def gen_row_terms(): + for row in range(underlying.shape[0]): + + terms = {} + + for col in range(underlying.shape[1]): + try: + underlying_terms = underlying.get_poly(row, col) + except KeyError: + continue + + terms[0, col] = underlying_terms + + yield init_expression(underlying=init_from_terms_expr( + terms=terms, + shape=(1, underlying.shape[1]) + )) + + row_terms = tuple(gen_row_terms()) + + return state, row_terms + + return init_state_monad(func) + def shape( expr: Expression, ) -> StateMonadMixin[ExpressionState, tuple[int, ...]]: @@ -550,50 +566,40 @@ def to_constant_matrix( A = np.zeros(underlying.shape, dtype=np.float32) - for row in range(underlying.shape[0]): - for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: - continue - - for monomial, value in underlying_terms.items(): + for (row, col), polynomial in underlying.get_terms(): + for monomial, value in polynomial.items(): - if len(monomial) == 0: - A[row, col] = value + if len(monomial) == 0: + A[row, col] = value return state, A return init_state_monad(func) - -def rows( +def to_sympy_expr( expr: Expression, -) -> StateMonadMixin[ExpressionState, np.ndarray]: +) -> StateMonadMixin[ExpressionState, sympy.Expr]: def func(state: ExpressionState): state, underlying = expr.apply(state) - def gen_row_terms(): - for row in range(underlying.shape[0]): + A = np.zeros(underlying.shape, dtype=np.object) - terms = {} + for (row, col), polynomial in underlying.get_terms(): - for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: - continue + sympy_polynomial = 0 - terms[0, col] = underlying_terms + for monomial, value in polynomial.items(): + sympy_monomial = 1 - yield init_expression(underlying=init_from_terms_expr( - terms=terms, - shape=(1, underlying.shape[1]) - )) + for offset, count in monomial: + var = state.get_variable_from_offset(offset) + sympy_monomial *= var**count - row_terms = tuple(gen_row_terms()) + sympy_polynomial += value * sympy_monomial - return state, row_terms + A[row, col] = sympy_polynomial + + return state, A return init_state_monad(func) |