summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/__init__.py96
1 files changed, 51 insertions, 45 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 1c072b4..8ee1488 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -4,6 +4,7 @@ import itertools
import typing
import numpy as np
import scipy.sparse
+import sympy
from polymatrix.expression.expression import Expression
from polymatrix.expressionstate.expressionstate import ExpressionState
@@ -20,10 +21,10 @@ from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin
from polymatrix.expression.utils.monomialtoindex import monomial_to_index
from polymatrix.expressionstate.init.initexpressionstate import init_expression_state as original_init_expression_state
+
def init_expression_state():
return original_init_expression_state()
-
def from_sympy(
data: tuple[tuple[float]],
):
@@ -31,13 +32,11 @@ def from_sympy(
init_from_sympy_expr(data)
)
-
def from_(
data: tuple[tuple[float]],
):
return from_sympy(data)
-
def from_polymatrix(
polymatrix: PolyMatrix,
):
@@ -45,7 +44,6 @@ def from_polymatrix(
init_from_terms_expr(polymatrix)
)
-
def v_stack(
expressions: tuple[Expression],
):
@@ -53,7 +51,6 @@ def v_stack(
init_v_stack_expr(expressions)
)
-
def h_stack(
expressions: tuple[Expression],
):
@@ -61,7 +58,6 @@ def h_stack(
init_v_stack_expr(tuple(expr.T for expr in expressions))
).T
-
def block_diag(
expressions: tuple[Expression],
):
@@ -69,7 +65,6 @@ def block_diag(
init_block_diag_expr(expressions)
)
-
def eye(
variable: tuple[Expression],
):
@@ -77,16 +72,6 @@ def eye(
init_eye_expr(variable=variable)
)
-
-# def sos_matrix(
-# underlying: Expression,
-# ):
-# def func(state: ExpressionState):
-# underlying =
-
-# return init_state_monad(func)
-
-
def kkt_equality(
variables: Expression,
equality: Expression = None,
@@ -269,6 +254,37 @@ def kkt_inequality(
return init_state_monad(func)
+def rows(
+ expr: Expression,
+) -> StateMonadMixin[ExpressionState, np.ndarray]:
+
+ def func(state: ExpressionState):
+ state, underlying = expr.apply(state)
+
+ def gen_row_terms():
+ for row in range(underlying.shape[0]):
+
+ terms = {}
+
+ for col in range(underlying.shape[1]):
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ terms[0, col] = underlying_terms
+
+ yield init_expression(underlying=init_from_terms_expr(
+ terms=terms,
+ shape=(1, underlying.shape[1])
+ ))
+
+ row_terms = tuple(gen_row_terms())
+
+ return state, row_terms
+
+ return init_state_monad(func)
+
def shape(
expr: Expression,
) -> StateMonadMixin[ExpressionState, tuple[int, ...]]:
@@ -550,50 +566,40 @@ def to_constant_matrix(
A = np.zeros(underlying.shape, dtype=np.float32)
- for row in range(underlying.shape[0]):
- for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
- continue
-
- for monomial, value in underlying_terms.items():
+ for (row, col), polynomial in underlying.get_terms():
+ for monomial, value in polynomial.items():
- if len(monomial) == 0:
- A[row, col] = value
+ if len(monomial) == 0:
+ A[row, col] = value
return state, A
return init_state_monad(func)
-
-def rows(
+def to_sympy_expr(
expr: Expression,
-) -> StateMonadMixin[ExpressionState, np.ndarray]:
+) -> StateMonadMixin[ExpressionState, sympy.Expr]:
def func(state: ExpressionState):
state, underlying = expr.apply(state)
- def gen_row_terms():
- for row in range(underlying.shape[0]):
+ A = np.zeros(underlying.shape, dtype=np.object)
- terms = {}
+ for (row, col), polynomial in underlying.get_terms():
- for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
- continue
+ sympy_polynomial = 0
- terms[0, col] = underlying_terms
+ for monomial, value in polynomial.items():
+ sympy_monomial = 1
- yield init_expression(underlying=init_from_terms_expr(
- terms=terms,
- shape=(1, underlying.shape[1])
- ))
+ for offset, count in monomial:
+ var = state.get_variable_from_offset(offset)
+ sympy_monomial *= var**count
- row_terms = tuple(gen_row_terms())
+ sympy_polynomial += value * sympy_monomial
- return state, row_terms
+ A[row, col] = sympy_polynomial
+
+ return state, A
return init_state_monad(func)