summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-05-03 11:13:25 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-05-03 11:13:25 +0200
commita9aa5ea1b606a160596f684cc3c4d3b9a578f287 (patch)
treeea1993a9cac1b871321f75244a636599f6ec3203
parentbug fixes and clean ups (diff)
downloadpolymatrix-a9aa5ea1b606a160596f684cc3c4d3b9a578f287.tar.gz
polymatrix-a9aa5ea1b606a160596f684cc3c4d3b9a578f287.zip
add statemonad syntax
Diffstat (limited to '')
-rw-r--r--polymatrix/__init__.py258
-rw-r--r--polymatrix/expression/forallexpr.py4
-rw-r--r--polymatrix/expression/impl/linearinexprimpl.py (renamed from polymatrix/expression/impl/forallexprimpl.py)4
-rw-r--r--polymatrix/expression/impl/linearmatrixinexprimpl.py9
-rw-r--r--polymatrix/expression/init/initevalexpr.py14
-rw-r--r--polymatrix/expression/init/initfromarrayexpr.py2
-rw-r--r--polymatrix/expression/init/initlinearinexpr.py (renamed from polymatrix/expression/init/initforallexpr.py)6
-rw-r--r--polymatrix/expression/init/initlinearmatrixinexpr.py12
-rw-r--r--polymatrix/expression/linearinexpr.py4
-rw-r--r--polymatrix/expression/linearmatrixinexpr.py4
-rw-r--r--polymatrix/expression/mixins/accumulateexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py46
-rw-r--r--polymatrix/expression/mixins/determinantexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/divisionexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/evalexprmixin.py14
-rw-r--r--polymatrix/expression/mixins/expressionbasemixin.py15
-rw-r--r--polymatrix/expression/mixins/expressionmixin.py22
-rw-r--r--polymatrix/expression/mixins/fromarrayexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/fromtermsexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/getitemexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/kktexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/linearinexprmixin.py (renamed from polymatrix/expression/mixins/forallexprmixin.py)11
-rw-r--r--polymatrix/expression/mixins/linearmatrixinexprmixin.py64
-rw-r--r--polymatrix/expression/mixins/matrixmultexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/parametrizetermsexprmixin.py99
-rw-r--r--polymatrix/expression/mixins/polymatrixmixin.py162
-rw-r--r--polymatrix/expression/mixins/quadraticinexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/repmatexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/toquadraticexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/transposeexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/vstackexprmixin.py8
-rw-r--r--polymatrix/expression/utils/getvariableindices.py26
-rw-r--r--polymatrix/polysolver.py99
-rw-r--r--polymatrix/statemonad/__init__.py16
-rw-r--r--polymatrix/statemonad/impl/__init__.py0
-rw-r--r--polymatrix/statemonad/impl/statemonadimpl.py8
-rw-r--r--polymatrix/statemonad/init/__init__.py0
-rw-r--r--polymatrix/statemonad/init/initstatemonad.py10
-rw-r--r--polymatrix/statemonad/mixins/__init__.py0
-rw-r--r--polymatrix/statemonad/mixins/statemonadmixin.py53
-rw-r--r--polymatrix/statemonad/statemonad.py4
43 files changed, 711 insertions, 326 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 78b4afa..139e391 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -1,11 +1,22 @@
+import itertools
+import numpy as np
+import scipy.sparse
+# import polymatrix.statemonad
+
from polymatrix.expression.expression import Expression
+from polymatrix.expression.expressionstate import ExpressionState
from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr
from polymatrix.expression.init.initexpression import init_expression
from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr
from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
from polymatrix.expression.init.initkktexpr import init_kkt_expr
+from polymatrix.expression.init.initlinearmatrixinexpr import init_linear_matrix_in_expr
from polymatrix.expression.init.initvstackexpr import init_v_stack_expr
from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.statemonad.init.initstatemonad import init_state_monad
+from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin
+from polymatrix.utils import monomial_to_index
def from_(
@@ -66,4 +77,249 @@ def kkt(
)
)
-
+ # self_cost = cost
+ # self_variables = variables
+ # self_equality = equality
+
+ # def func(state: ExpressionState):
+ # state, cost = self_cost.apply(state=state)
+
+ # assert cost.shape[1] == 1
+
+ # if self_equality is not None:
+
+ # state, equality = self_equality.apply(state=state)
+
+ # state, equality_der = self_equality.diff(
+ # self_variables,
+ # introduce_derivatives=True,
+ # ).apply(state)
+
+ # assert cost.shape[0] == equality_der.shape[1]
+
+ # def acc_nu_variables(acc, v):
+ # state, nu_variables = acc
+
+ # nu_variable = state.n_param
+ # state = state.register(n_param=1)
+
+ # return state, nu_variables + [nu_variable]
+
+ # *_, (state, nu_variables) = tuple(itertools.accumulate(
+ # range(equality.shape[0]),
+ # acc_nu_variables,
+ # initial=(state, []),
+ # ))
+
+ # else:
+ # nu_variables = tuple()
+
+ # idx_start = 0
+
+ # terms = {}
+
+ # for row in range(cost.shape[0]):
+ # try:
+ # monomial_terms = cost.get_poly(row, 0)
+ # except KeyError:
+ # monomial_terms = {}
+
+ # for eq_idx, nu_variable in enumerate(nu_variables):
+
+ # try:
+ # underlying_terms = equality_der.get_poly(eq_idx, row)
+ # except KeyError:
+ # continue
+
+ # for monomial, value in underlying_terms.items():
+ # new_monomial = monomial + (nu_variable,)
+
+ # if new_monomial not in monomial_terms:
+ # monomial_terms[new_monomial] = 0
+
+ # monomial_terms[new_monomial] += value
+
+ # terms[idx_start, 0] = monomial_terms
+ # idx_start += 1
+
+ # cost_expr = init_expression(init_from_terms_expr(
+ # terms=terms,
+ # shape=(idx_start, 1),
+ # ))
+
+ # terms = {}
+ # for eq_idx, nu_variable in enumerate(nu_variables):
+ # terms[eq_idx, 0] = {(nu_variable,): 1}
+
+ # nu_expr = init_expression(init_from_terms_expr(
+ # terms=terms,
+ # shape=(len(nu_variables), 1),
+ # ))
+
+ # return state, (cost_expr, nu_expr)
+
+ # return StateMonadMixin.init(func)
+
+# def to_linear_matrix(
+# expr: Expression,
+# variables: Expression,
+# ) -> StateMonad[ExpressionState, tuple[Expression, ...]]:
+# def func(state: ExpressionState):
+# state, variable_indices = get_variable_indices(state, variables)
+
+# def gen_matrices():
+# for variable_index in variable_indices:
+# yield init_linear_matrix_in_expr(
+# underlying=expr,
+# variable=variable_index,
+# )
+
+# matrices = tuple(gen_matrices())
+
+# return state, matrices
+
+# return StateMonad.init(func)
+
+
+def to_matrix_equations(
+ expr: Expression,
+) -> StateMonadMixin[ExpressionState, tuple[tuple[np.ndarray, ...], tuple[int, ...]]]:
+ def func(state: ExpressionState):
+ state, underlying = expr.apply(state)
+
+ assert underlying.shape[1] == 1
+
+ def gen_used_variables():
+ def gen_used_auxillary_variables(considered):
+ monomial_terms = state.auxillary_equations[considered[-1]]
+ for monomial in monomial_terms.keys():
+ for variable in monomial:
+ yield variable
+
+ if variable not in considered and variable in state.auxillary_equations:
+ yield from gen_used_auxillary_variables(considered + (variable,))
+
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ for monomial in underlying_terms.keys():
+ for variable in monomial:
+ yield variable
+
+ if variable in state.auxillary_equations:
+ yield from gen_used_auxillary_variables((variable,))
+
+ used_variables = set(gen_used_variables())
+
+ ordered_variable_index = tuple(sorted(used_variables))
+ variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
+
+ n_param = len(ordered_variable_index)
+
+ A = np.zeros((n_param, 1), dtype=np.float32)
+ B = np.zeros((n_param, n_param), dtype=np.float32)
+ C = scipy.sparse.dok_array((n_param, n_param**2), dtype=np.float32)
+
+ def populate_matrices(monomial_terms, row):
+ for monomial, value in monomial_terms.items():
+ new_monomial = tuple(variable_index_map[var] for var in monomial)
+ col = monomial_to_index(n_param, new_monomial)
+
+ match len(new_monomial):
+ case 0:
+ A[row, col] = value
+ case 1:
+ B[row, col] = value
+ case 2:
+ C[row, col] = value
+ case _:
+ raise Exception(f'illegal case {new_monomial=}')
+
+ for row in range(underlying.shape[0]):
+ try:
+ underlying_terms = underlying.get_poly(row, 0)
+ except KeyError:
+ continue
+
+ populate_matrices(
+ monomial_terms=underlying_terms,
+ row=row,
+ )
+
+ current_row = underlying.shape[0]
+
+ for key, monomial_terms in state.auxillary_equations.items():
+ if key in ordered_variable_index:
+ populate_matrices(
+ monomial_terms=monomial_terms,
+ row=current_row,
+ )
+ current_row += 1
+
+ # assert current_row == n_param, f'{current_row} is not {n_param}'
+
+ return state, ((A, B, C), ordered_variable_index)
+
+ return StateMonadMixin.init(func)
+
+def to_constant_matrix(
+ expr: Expression,
+) -> StateMonadMixin[ExpressionState, np.ndarray]:
+
+ def func(state: ExpressionState):
+ state, underlying = expr.apply(state)
+
+ A = np.zeros(underlying.shape, dtype=np.float32)
+
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ for monomial, value in underlying_terms.items():
+
+ if len(monomial) == 0:
+ A[row, col] = value
+
+ return state, A
+
+ return init_state_monad(func)
+
+
+def rows(
+ expr: Expression,
+) -> StateMonadMixin[ExpressionState, np.ndarray]:
+
+ def func(state: ExpressionState):
+ state, underlying = expr.apply(state)
+
+ def gen_row_terms():
+ for row in range(underlying.shape[0]):
+
+ terms = {}
+
+ for col in range(underlying.shape[1]):
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ terms[0, col] = underlying_terms
+
+ yield init_expression(underlying=init_from_terms_expr(
+ terms=terms,
+ shape=(1, underlying.shape[1])
+ ))
+
+ row_terms = tuple(gen_row_terms())
+
+ return state, row_terms
+
+ return init_state_monad(func)
diff --git a/polymatrix/expression/forallexpr.py b/polymatrix/expression/forallexpr.py
deleted file mode 100644
index 972d553..0000000
--- a/polymatrix/expression/forallexpr.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from polymatrix.expression.mixins.forallexprmixin import ForAllExprMixin
-
-class ForAllExpr(ForAllExprMixin):
- pass
diff --git a/polymatrix/expression/impl/forallexprimpl.py b/polymatrix/expression/impl/linearinexprimpl.py
index 0960479..d97051f 100644
--- a/polymatrix/expression/impl/forallexprimpl.py
+++ b/polymatrix/expression/impl/linearinexprimpl.py
@@ -1,9 +1,9 @@
import dataclass_abc
-from polymatrix.expression.forallexpr import ForAllExpr
+from polymatrix.expression.linearinexpr import LinearInExpr
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@dataclass_abc.dataclass_abc(frozen=True)
-class ForAllExprImpl(ForAllExpr):
+class LinearInExprImpl(LinearInExpr):
underlying: ExpressionBaseMixin
variables: tuple
diff --git a/polymatrix/expression/impl/linearmatrixinexprimpl.py b/polymatrix/expression/impl/linearmatrixinexprimpl.py
new file mode 100644
index 0000000..b40e6a6
--- /dev/null
+++ b/polymatrix/expression/impl/linearmatrixinexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.linearmatrixinexpr import LinearMatrixInExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class LinearMatrixInExprImpl(LinearMatrixInExpr):
+ underlying: ExpressionBaseMixin
+ variable: int
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py
index 9d0fc0d..525a697 100644
--- a/polymatrix/expression/init/initevalexpr.py
+++ b/polymatrix/expression/init/initevalexpr.py
@@ -5,14 +5,14 @@ from polymatrix.expression.impl.evalexprimpl import EvalExprImpl
def init_eval_expr(
underlying: ExpressionBaseMixin,
- values: tuple,
- variables: tuple = None,
+ variables: tuple,
+ values: tuple = None,
):
- if variables is None:
- assert isinstance(values, tuple)
+ if values is None:
+ assert isinstance(variables, tuple)
- variables, values = tuple(zip(*values))
+ variables, values = tuple(zip(*variables))
elif isinstance(values, np.ndarray):
values = tuple(values)
@@ -20,8 +20,8 @@ def init_eval_expr(
elif not isinstance(values, tuple):
values = (values,)
- if not isinstance(variables, tuple):
- variables = (variables,)
+ # if not isinstance(variables, tuple):
+ # variables = (variables,)
return EvalExprImpl(
underlying=underlying,
diff --git a/polymatrix/expression/init/initfromarrayexpr.py b/polymatrix/expression/init/initfromarrayexpr.py
index 6aab26c..e6f57c8 100644
--- a/polymatrix/expression/init/initfromarrayexpr.py
+++ b/polymatrix/expression/init/initfromarrayexpr.py
@@ -21,7 +21,7 @@ def init_from_array_expr(
assert all(len(col) == n_col for col in data)
case _:
- data = (data,)
+ data = tuple((e,) for e in data)
case _:
data = ((data,),)
diff --git a/polymatrix/expression/init/initforallexpr.py b/polymatrix/expression/init/initlinearinexpr.py
index 84388d2..f7f76e4 100644
--- a/polymatrix/expression/init/initforallexpr.py
+++ b/polymatrix/expression/init/initlinearinexpr.py
@@ -1,12 +1,12 @@
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.expression.impl.forallexprimpl import ForAllExprImpl
+from polymatrix.expression.impl.linearinexprimpl import LinearInExprImpl
-def init_for_all_expr(
+def init_linear_in_expr(
underlying: ExpressionBaseMixin,
variables: tuple,
):
- return ForAllExprImpl(
+ return LinearInExprImpl(
underlying=underlying,
variables=variables,
)
diff --git a/polymatrix/expression/init/initlinearmatrixinexpr.py b/polymatrix/expression/init/initlinearmatrixinexpr.py
new file mode 100644
index 0000000..cd4ce97
--- /dev/null
+++ b/polymatrix/expression/init/initlinearmatrixinexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.linearmatrixinexprimpl import LinearMatrixInExprImpl
+
+
+def init_linear_matrix_in_expr(
+ underlying: ExpressionBaseMixin,
+ variable: int,
+):
+ return LinearMatrixInExprImpl(
+ underlying=underlying,
+ variable=variable,
+)
diff --git a/polymatrix/expression/linearinexpr.py b/polymatrix/expression/linearinexpr.py
new file mode 100644
index 0000000..4edf8b3
--- /dev/null
+++ b/polymatrix/expression/linearinexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin
+
+class LinearInExpr(LinearInExprMixin):
+ pass
diff --git a/polymatrix/expression/linearmatrixinexpr.py b/polymatrix/expression/linearmatrixinexpr.py
new file mode 100644
index 0000000..2bce2e7
--- /dev/null
+++ b/polymatrix/expression/linearmatrixinexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.linearmatrixinexprmixin import LinearMatrixInExprMixin
+
+class LinearMatrixInExpr(LinearMatrixInExprMixin):
+ pass
diff --git a/polymatrix/expression/mixins/accumulateexprmixin.py b/polymatrix/expression/mixins/accumulateexprmixin.py
index 1f834bf..76e1717 100644
--- a/polymatrix/expression/mixins/accumulateexprmixin.py
+++ b/polymatrix/expression/mixins/accumulateexprmixin.py
@@ -24,7 +24,7 @@ class AccumulateExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index e068ba1..d274f2a 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -26,14 +26,14 @@ class AdditionExprMixin(ExpressionBaseMixin):
# return self.left.shape
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
- assert left.shape == right.shape
+ assert left.shape == right.shape, f'{left.shape} != {right.shape}'
terms = {}
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index 43c0efe..a551c1f 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -4,8 +4,6 @@ import collections
import dataclasses
import itertools
import typing
-import dataclass_abc
-from numpy import nonzero, var
from polymatrix.expression.init.initderivativekey import init_derivative_key
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
@@ -31,51 +29,15 @@ class DerivativeExprMixin(ExpressionBaseMixin):
def introduce_derivatives(self) -> bool:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # match self.variables:
- # case ExpressionBaseMixin():
- # n_cols = self.variables.shape[0]
-
- # case _:
- # n_cols = len(self.variables)
-
- # return self.underlying.shape[0], n_cols
-
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
- # match self.variables:
- # case ExpressionBaseMixin():
- # state, variables = self.variables.apply(state)
-
- # assert variables.shape[1] == 1
-
- # def gen_indices():
- # for row in range(variables.shape[0]):
- # row_terms = variables.get_poly(row, 0)
-
- # assert len(row_terms) == 1
-
- # for monomial in row_terms.keys():
- # assert len(monomial) == 1
- # yield monomial[0]
-
- # diff_wrt_variables = tuple(gen_indices())
-
- # case _:
- # def gen_indices():
- # for variable in self.variables:
- # if variable in state.offset_dict:
- # yield state.offset_dict[variable][0]
-
- # diff_wrt_variables = tuple(gen_indices())
+ state, underlying = self.underlying.apply(state=state)
- state, diff_wrt_variables = get_variable_indices(self.variables, state)
+ state, diff_wrt_variables = get_variable_indices(state, self.variables)
def get_derivative_terms(
monomial_terms,
@@ -181,8 +143,6 @@ class DerivativeExprMixin(ExpressionBaseMixin):
return state, dict(derivation_terms)
- state, underlying = self.underlying.apply(state=state)
-
terms = {}
for row in range(underlying.shape[0]):
diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py
index f7f76eb..5b69d3f 100644
--- a/polymatrix/expression/mixins/determinantexprmixin.py
+++ b/polymatrix/expression/mixins/determinantexprmixin.py
@@ -22,7 +22,7 @@ class DeterminantExprMixin(ExpressionBaseMixin):
# return self.underlying.shape[0], 1
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py
index 7b5e0c5..4b25ec6 100644
--- a/polymatrix/expression/mixins/divisionexprmixin.py
+++ b/polymatrix/expression/mixins/divisionexprmixin.py
@@ -25,7 +25,7 @@ class DivisionExprMixin(ExpressionBaseMixin):
# return self.left.shape
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index 931e022..ca5d921 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -25,7 +25,7 @@ class ElemMultExprMixin(ExpressionBaseMixin):
# return self.left.shape
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py
index e61cebd..4d78b26 100644
--- a/polymatrix/expression/mixins/evalexprmixin.py
+++ b/polymatrix/expression/mixins/evalexprmixin.py
@@ -26,15 +26,21 @@ class EvalExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- state, variable_indices = get_variable_indices(self.variables, state)
+ state, variable_indices = get_variable_indices(state, self.variables)
- assert len(variable_indices) == len(self.values)
+ if len(self.values) == 1:
+ values = tuple(self.values[0] for _ in variable_indices)
+
+ else:
+ assert len(variable_indices) == len(self.values)
+
+ values = self.values
terms = {}
@@ -55,7 +61,7 @@ class EvalExprMixin(ExpressionBaseMixin):
if variable in variable_indices:
index = variable_indices.index(variable)
- new_value = value * self.values[index]
+ new_value = value * values[index]
return monomial, new_value
else:
diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py
index 43b3fdb..baf605d 100644
--- a/polymatrix/expression/mixins/expressionbasemixin.py
+++ b/polymatrix/expression/mixins/expressionbasemixin.py
@@ -1,17 +1,20 @@
import abc
from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin
class ExpressionBaseMixin(
- abc.ABC,
+ # StateMonad[ExpressionStateMixin, PolyMatrixMixin],
+ abc.ABC
):
- # @property
- # @abc.abstractclassmethod
- # def shape(self) -> tuple[int, int]:
- # ...
+ @abc.abstractmethod
+ def _apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+ ...
@abc.abstractmethod
def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
- ...
+ assert isinstance(state, ExpressionStateMixin), f'{state} is not of type {ExpressionStateMixin.__name__}'
+
+ return self._apply(state)
diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py
index b43b3be..aa9ca17 100644
--- a/polymatrix/expression/mixins/expressionmixin.py
+++ b/polymatrix/expression/mixins/expressionmixin.py
@@ -10,9 +10,10 @@ from polymatrix.expression.init.initdeterminantexpr import init_determinant_expr
from polymatrix.expression.init.initdivisionexpr import init_division_expr
from polymatrix.expression.init.initelemmultexpr import init_elem_mult_expr
from polymatrix.expression.init.initevalexpr import init_eval_expr
-from polymatrix.expression.init.initforallexpr import init_for_all_expr
+from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr
from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr
from polymatrix.expression.init.initgetitemexpr import init_get_item_expr
+from polymatrix.expression.init.initlinearmatrixinexpr import init_linear_matrix_in_expr
from polymatrix.expression.init.initmatrixmultexpr import init_matrix_mult_expr
from polymatrix.expression.init.initparametrizetermsexpr import init_parametrize_terms_expr
from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr
@@ -188,12 +189,21 @@ class ExpressionMixin(
def linear_in(self, variables: tuple) -> 'ExpressionMixin':
return dataclasses.replace(
self,
- underlying=init_for_all_expr(
+ underlying=init_linear_in_expr(
underlying=self.underlying,
variables=variables,
),
)
+ def linear_matrix_in(self, variable) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_linear_matrix_in_expr(
+ underlying=self.underlying,
+ variable=variable,
+ ),
+ )
+
def quadratic_in(self, variables: tuple) -> 'ExpressionMixin':
return dataclasses.replace(
self,
@@ -213,15 +223,15 @@ class ExpressionMixin(
def eval(
self,
- values: tuple[float, ...],
- variables: tuple = None,
+ variable: tuple,
+ value: tuple[float, ...] = None,
) -> 'ExpressionMixin':
return dataclasses.replace(
self,
underlying=init_eval_expr(
underlying=self.underlying,
- variables=variables,
- values=values,
+ variables=variable,
+ values=value,
),
)
diff --git a/polymatrix/expression/mixins/fromarrayexprmixin.py b/polymatrix/expression/mixins/fromarrayexprmixin.py
index 86f8b72..e79810c 100644
--- a/polymatrix/expression/mixins/fromarrayexprmixin.py
+++ b/polymatrix/expression/mixins/fromarrayexprmixin.py
@@ -15,13 +15,8 @@ class FromArrayExprMixin(ExpressionBaseMixin):
def data(self) -> tuple[tuple[float]]:
pass
- # # overwrites abstract method of `PolyMatrixExprBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return len(self.data), len(self.data[0])
-
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py
index 4ada8ec..1af3f11 100644
--- a/polymatrix/expression/mixins/fromtermsexprmixin.py
+++ b/polymatrix/expression/mixins/fromtermsexprmixin.py
@@ -21,7 +21,7 @@ class FromTermsExprMixin(ExpressionBaseMixin):
pass
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py
index 26ed728..c8ac02e 100644
--- a/polymatrix/expression/mixins/getitemexprmixin.py
+++ b/polymatrix/expression/mixins/getitemexprmixin.py
@@ -22,13 +22,8 @@ class GetItemExprMixin(ExpressionBaseMixin):
def index(self) -> tuple[int, int]:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return 1, 1
-
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/kktexprmixin.py b/polymatrix/expression/mixins/kktexprmixin.py
index b4c9e72..20949a1 100644
--- a/polymatrix/expression/mixins/kktexprmixin.py
+++ b/polymatrix/expression/mixins/kktexprmixin.py
@@ -39,7 +39,7 @@ class KKTExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
@@ -170,7 +170,7 @@ class KKTExprMixin(ExpressionBaseMixin):
except KeyError:
continue
- # f(x) <= 0
+ # f(x) <= -0.01
terms[idx_start, 0] = underlying_terms | {(r_inequality, r_inequality): 1}
idx_start += 1
diff --git a/polymatrix/expression/mixins/forallexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py
index a988b24..557bed0 100644
--- a/polymatrix/expression/mixins/forallexprmixin.py
+++ b/polymatrix/expression/mixins/linearinexprmixin.py
@@ -9,7 +9,7 @@ from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.expression.expressionstate import ExpressionState
-class ForAllExprMixin(ExpressionBaseMixin):
+class LinearInExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
def underlying(self) -> ExpressionBaseMixin:
@@ -20,18 +20,15 @@ class ForAllExprMixin(ExpressionBaseMixin):
def variables(self) -> tuple:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return self.underlying.shape
-
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
+ # todo: uncomment this
+ # state, variable_indices = get_variable_indices(state, self.variables)
variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
terms = {}
diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py
new file mode 100644
index 0000000..321d955
--- /dev/null
+++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py
@@ -0,0 +1,64 @@
+
+import abc
+import collections
+from numpy import var
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.expressionstate import ExpressionState
+from polymatrix.expression.utils.getvariableindices import get_variable_indices
+
+
+class LinearMatrixInExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def variable(self) -> int:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def _apply(
+ self,
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ state, variable_indices = get_variable_indices(state, self.variable)
+
+ assert len(variable_indices) == 1
+
+ terms = {}
+
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ monomial_terms = collections.defaultdict(float)
+
+ for monomial, value in underlying_terms.items():
+
+ x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+
+ # only take linear terms
+ if len(x_monomial) == 1:
+ p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
+
+ monomial_terms[p_monomial] += value
+
+ terms[row, col] = dict(monomial_terms)
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=underlying.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py
index d1c96d4..f1e1720 100644
--- a/polymatrix/expression/mixins/matrixmultexprmixin.py
+++ b/polymatrix/expression/mixins/matrixmultexprmixin.py
@@ -19,13 +19,8 @@ class MatrixMultExprMixin(ExpressionBaseMixin):
def right(self) -> ExpressionBaseMixin:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return (self.left.shape[0], self.right.shape[1])
-
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/parametrizetermsexprmixin.py b/polymatrix/expression/mixins/parametrizetermsexprmixin.py
index 4d46ece..f20306c 100644
--- a/polymatrix/expression/mixins/parametrizetermsexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizetermsexprmixin.py
@@ -32,46 +32,46 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
# def shape(self) -> tuple[int, int]:
# return self.underlying.shape
- @dataclass_abc.dataclass_abc
- class ParametrizeTermsPolyMatrix(PolyMatrixAsDictMixin):
- shape: tuple[int, int]
- terms: dict
- start_index: int
- n_param: int
+ # @dataclass_abc.dataclass_abc
+ # class ParametrizeTermsPolyMatrix(PolyMatrixAsDictMixin):
+ # shape: tuple[int, int]
+ # terms: dict
+ # start_index: int
+ # n_param: int
- @property
- def param(self) -> tuple[int, int]:
- outer_self = self
+ # @property
+ # def param(self) -> tuple[int, int]:
+ # outer_self = self
- @dataclass_abc.dataclass_abc(frozen=True)
- class ParameterExpr(ExpressionBaseMixin):
- def apply(
- self,
- state: ExpressionStateMixin,
- ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+ # @dataclass_abc.dataclass_abc(frozen=True)
+ # class ParameterExpr(ExpressionBaseMixin):
+ # def apply(
+ # self,
+ # state: ExpressionStateMixin,
+ # ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
- state, poly_matrix = outer_self.apply(state)
+ # state, poly_matrix = outer_self.apply(state)
- n_param = poly_matrix.n_param
- start_index = poly_matrix.start_index
+ # n_param = poly_matrix.n_param
+ # start_index = poly_matrix.start_index
- def gen_monomials():
- for rel_index in range(n_param):
- yield {(start_index + rel_index,): 1}
+ # def gen_monomials():
+ # for rel_index in range(n_param):
+ # yield {(start_index + rel_index,): 1}
- terms = {(row, 0): monomial_terms for row, monomial_terms in enumerate(gen_monomials())}
+ # terms = {(row, 0): monomial_terms for row, monomial_terms in enumerate(gen_monomials())}
- poly_matrix = init_poly_matrix(
- terms=terms,
- shape=(n_param, 1),
- )
+ # poly_matrix = init_poly_matrix(
+ # terms=terms,
+ # shape=(n_param, 1),
+ # )
- return state, poly_matrix
+ # return state, poly_matrix
- return ParameterExpr()
+ # return ParameterExpr()
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
@@ -87,8 +87,7 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
- idx_start = state.n_param
- n_param = 0
+ start_index = state.n_param
terms = {}
for row in range(underlying.shape[0]):
@@ -99,32 +98,42 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
except KeyError:
continue
- terms_row_col = {}
- collected_terms = []
+ def gen_x_monomial_terms():
+ for monomial, value in underlying_terms.items():
+ x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+ yield monomial, x_monomial, value
- for monomial, value in underlying_terms.items():
+ x_monomial_terms = tuple(gen_x_monomial_terms())
- x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+ collected_terms = tuple(sorted(set((x_monomial for _, x_monomial, _ in x_monomial_terms))))
- if x_monomial not in collected_terms:
- collected_terms.append(x_monomial)
-
- idx = idx_start + n_param + collected_terms.index(x_monomial)
+ terms_row_col = {}
- new_monomial = monomial + (idx,)
+ for monomial, x_monomial, value in x_monomial_terms:
+
+ new_monomial = monomial + (start_index + collected_terms.index(x_monomial),)
terms_row_col[new_monomial] = value
- n_param += len(collected_terms)
terms[row, col] = terms_row_col
- state = state.register(key=self, n_param=n_param)
+ start_index += len(collected_terms)
+
+ state = state.register(
+ key=self,
+ n_param=start_index - state.n_param,
+ )
+
+ # poly_matrix = ParametrizeTermsExprMixin.ParametrizeTermsPolyMatrix(
+ # terms=terms,
+ # shape=underlying.shape,
+ # start_index=idx_start,
+ # n_param=n_param,
+ # )
- poly_matrix = ParametrizeTermsExprMixin.ParametrizeTermsPolyMatrix(
+ poly_matrix = init_poly_matrix(
terms=terms,
shape=underlying.shape,
- start_index=idx_start,
- n_param=n_param,
)
state = dataclasses.replace(
diff --git a/polymatrix/expression/mixins/polymatrixmixin.py b/polymatrix/expression/mixins/polymatrixmixin.py
index c0dcac2..f83a62c 100644
--- a/polymatrix/expression/mixins/polymatrixmixin.py
+++ b/polymatrix/expression/mixins/polymatrixmixin.py
@@ -18,84 +18,84 @@ class PolyMatrixMixin(abc.ABC):
def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
...
- def get_equations(
- self,
- state: ExpressionStateMixin,
- ):
- assert self.shape[1] == 1
-
- def gen_used_variables():
- def gen_used_auxillary_variables(considered):
- monomial_terms = state.auxillary_equations[considered[-1]]
- for monomial in monomial_terms.keys():
- for variable in monomial:
- yield variable
-
- if variable not in considered and variable in state.auxillary_equations:
- yield from gen_used_auxillary_variables(considered + (variable,))
-
- for row in range(self.shape[0]):
- for col in range(self.shape[1]):
-
- try:
- underlying_terms = self.get_poly(row, col)
- except KeyError:
- continue
-
- for monomial in underlying_terms.keys():
- for variable in monomial:
- yield variable
-
- if variable in state.auxillary_equations:
- yield from gen_used_auxillary_variables((variable,))
-
- used_variables = set(gen_used_variables())
-
- ordered_variable_index = tuple(sorted(used_variables))
- variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
-
- n_param = len(ordered_variable_index)
-
- A = np.zeros((n_param, 1), dtype=np.float32)
- B = np.zeros((n_param, n_param), dtype=np.float32)
- C = scipy.sparse.dok_array((n_param, n_param**2), dtype=np.float32)
-
- def populate_matrices(monomial_terms, row):
- for monomial, value in monomial_terms.items():
- new_monomial = tuple(variable_index_map[var] for var in monomial)
- col = monomial_to_index(n_param, new_monomial)
-
- match len(new_monomial):
- case 0:
- A[row, col] = value
- case 1:
- B[row, col] = value
- case 2:
- C[row, col] = value
- case _:
- raise Exception(f'illegal case {new_monomial=}')
-
- for row in range(self.shape[0]):
- try:
- underlying_terms = self.get_poly(row, 0)
- except KeyError:
- continue
-
- populate_matrices(
- monomial_terms=underlying_terms,
- row=row,
- )
-
- current_row = self.shape[0]
-
- for key, monomial_terms in state.auxillary_equations.items():
- if key in ordered_variable_index:
- populate_matrices(
- monomial_terms=monomial_terms,
- row=current_row,
- )
- current_row += 1
-
- # assert current_row == n_param, f'{current_row} is not {n_param}'
-
- return (A, B, C), ordered_variable_index
+ # def get_equations(
+ # self,
+ # state: ExpressionStateMixin,
+ # ):
+ # assert self.shape[1] == 1
+
+ # def gen_used_variables():
+ # def gen_used_auxillary_variables(considered):
+ # monomial_terms = state.auxillary_equations[considered[-1]]
+ # for monomial in monomial_terms.keys():
+ # for variable in monomial:
+ # yield variable
+
+ # if variable not in considered and variable in state.auxillary_equations:
+ # yield from gen_used_auxillary_variables(considered + (variable,))
+
+ # for row in range(self.shape[0]):
+ # for col in range(self.shape[1]):
+
+ # try:
+ # underlying_terms = self.get_poly(row, col)
+ # except KeyError:
+ # continue
+
+ # for monomial in underlying_terms.keys():
+ # for variable in monomial:
+ # yield variable
+
+ # if variable in state.auxillary_equations:
+ # yield from gen_used_auxillary_variables((variable,))
+
+ # used_variables = set(gen_used_variables())
+
+ # ordered_variable_index = tuple(sorted(used_variables))
+ # variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
+
+ # n_param = len(ordered_variable_index)
+
+ # A = np.zeros((n_param, 1), dtype=np.float32)
+ # B = np.zeros((n_param, n_param), dtype=np.float32)
+ # C = scipy.sparse.dok_array((n_param, n_param**2), dtype=np.float32)
+
+ # def populate_matrices(monomial_terms, row):
+ # for monomial, value in monomial_terms.items():
+ # new_monomial = tuple(variable_index_map[var] for var in monomial)
+ # col = monomial_to_index(n_param, new_monomial)
+
+ # match len(new_monomial):
+ # case 0:
+ # A[row, col] = value
+ # case 1:
+ # B[row, col] = value
+ # case 2:
+ # C[row, col] = value
+ # case _:
+ # raise Exception(f'illegal case {new_monomial=}')
+
+ # for row in range(self.shape[0]):
+ # try:
+ # underlying_terms = self.get_poly(row, 0)
+ # except KeyError:
+ # continue
+
+ # populate_matrices(
+ # monomial_terms=underlying_terms,
+ # row=row,
+ # )
+
+ # current_row = self.shape[0]
+
+ # for key, monomial_terms in state.auxillary_equations.items():
+ # if key in ordered_variable_index:
+ # populate_matrices(
+ # monomial_terms=monomial_terms,
+ # row=current_row,
+ # )
+ # current_row += 1
+
+ # # assert current_row == n_param, f'{current_row} is not {n_param}'
+
+ # return (A, B, C), ordered_variable_index
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
index 5aeef1b..ce6b5c2 100644
--- a/polymatrix/expression/mixins/quadraticinexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticinexprmixin.py
@@ -24,7 +24,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
# return 2*(len(self.variables),)
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
@@ -49,7 +49,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
- assert len(x_monomial) == 2
+ assert len(x_monomial) == 2, f'{x_monomial} should be of length 2'
assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted'
key = tuple(reversed(x_monomial))
diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py
index 26777e9..5d5c00f 100644
--- a/polymatrix/expression/mixins/repmatexprmixin.py
+++ b/polymatrix/expression/mixins/repmatexprmixin.py
@@ -16,12 +16,8 @@ class RepMatExprMixin(ExpressionBaseMixin):
def repetition(self) -> tuple[int, int]:
...
- # @property
- # def shape(self) -> tuple[int, int]:
- # return tuple(s*r for s, r in zip(self.underlying.shape, self.repetition))
-
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py
index 3079e0f..0aac30a 100644
--- a/polymatrix/expression/mixins/toquadraticexprmixin.py
+++ b/polymatrix/expression/mixins/toquadraticexprmixin.py
@@ -15,13 +15,8 @@ class ToQuadraticExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return self.underlying.shape
-
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py
index d24cf53..fa074d7 100644
--- a/polymatrix/expression/mixins/transposeexprmixin.py
+++ b/polymatrix/expression/mixins/transposeexprmixin.py
@@ -17,13 +17,8 @@ class TransposeExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # # overwrites abstract method of `PolyMatrixExprBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return self.underlying.shape[1], self.underlying.shape[0]
-
# overwrites abstract method of `PolyMatrixExprBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py
index c192ae3..90cfb20 100644
--- a/polymatrix/expression/mixins/vstackexprmixin.py
+++ b/polymatrix/expression/mixins/vstackexprmixin.py
@@ -15,14 +15,8 @@ class VStackExprMixin(ExpressionBaseMixin):
def underlying(self) -> tuple[ExpressionBaseMixin, ...]:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # n_row = sum(expr.shape[0] for expr in self.underlying)
- # return n_row, self.underlying[0].shape[1]
-
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index 9379e5f..46bb51b 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -1,7 +1,9 @@
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-def get_variable_indices(variables, state):
+def get_variable_indices(state, variables):
+
+ # print(f'{variables=}')
if isinstance(variables, ExpressionBaseMixin):
state, variables = variables.apply(state)
@@ -20,18 +22,20 @@ def get_variable_indices(variables, state):
return state, tuple(gen_indices())
- if not isinstance(variables, tuple):
- variables = (variables,)
+ else:
+
+ if not isinstance(variables, tuple):
+ variables = (variables,)
- # assert all(isinstance(variable, type(variables[0])) for variable in variables)
+ # assert all(isinstance(variable, type(variables[0])) for variable in variables)
- def gen_indices():
- for variable in variables:
+ def gen_indices():
+ for variable in variables:
- if isinstance(variable, int):
- yield variable
+ if isinstance(variable, int):
+ yield variable
- else:
- yield state.offset_dict[variable][0]
+ else:
+ yield state.offset_dict[variable][0]
- return state, tuple(gen_indices())
+ return state, tuple(gen_indices())
diff --git a/polymatrix/polysolver.py b/polymatrix/polysolver.py
index da6fd5d..8c055f8 100644
--- a/polymatrix/polysolver.py
+++ b/polymatrix/polysolver.py
@@ -153,10 +153,11 @@ def solve_poly_system(data, m):
else:
b_num_inv = scipy.sparse.linalg.inv(data[1])
- n = b_num_inv.shape[0]
- p = b_num_inv @ np.ones((n, 1))
+ n_var = b_num_inv.shape[0]
+ p0 = b_num_inv @ np.ones((n_var, 1))
- assert (data[0] + np.ones((n, 1)) < 0.01).all(), f'{data[0]=}, {data[0] + np.ones((n, 1))=}'
+ assert (data[0] + np.ones((n_var, 1)) < 0.01).all(), f'{data[0]=}, {data[0] + np.ones((n_var, 1))=}'
+ assert (data[0] + np.ones((n_var, 1)) < 0.01).all(), f'data[0] is not one'
def func(acc, _):
"""
@@ -182,76 +183,73 @@ def solve_poly_system(data, m):
n - (2*d-1)*idx,
d-1,
))
-
- def acc_kron(l):
- *_, last = itertools.accumulate(l, lambda acc, v: np.kron(acc, v))
- return last
def gen_p():
for degree, d_data in data.items():
if 1 < degree and degree-1 <= k:
indices_list = list(gen_indices(k-degree+1, degree-1))
- permutations = (perm for indices in indices_list for perm in more_itertools.distinct_permutations(indices))
+ permutations = lambda: (perm for indices in indices_list for perm in more_itertools.distinct_permutations(indices))
- if not scipy.sparse.issparse(data):
+ if not scipy.sparse.issparse(d_data):
def acc_kron(perm):
*_, last = itertools.accumulate(((p[:,idx:idx+1] for idx in perm)), lambda acc, v: np.kron(acc, v))
return last
- yield from (d_data @ acc_kron(perm) for perm in permutations)
+ yield from (d_data @ acc_kron(perm) for perm in permutations())
else:
- csr_data = d_data.tocsr()
-
- n_eye = sum(1 for idx in indices if idx == 0)
- n_col = n**n_eye
- n_index = n_col * n**(len(indices) - n_eye)
- def gen_array_per_permuatation():
+ csr_data = d_data.tocsr()
- def gen_coord(perm):
- n_row = csr_data.shape[0]
+ n_row = csr_data.shape[0]
+ n_col = csr_data.shape[1]
- def acc_col_idx_and_value(acc, v):
- relindex, relrow, val = acc
+ def gen_row_values():
+
+ def acc_kron_operation(acc, v):
+ relindex, relrow, val = acc
- n_relrow = relrow / n
- grp_index = int(relindex / n_relrow)
- n_relindex = int(relindex - grp_index * n_relrow)
- n_val = val * p[grp_index, v:v+1]
+ n_relrow = relrow / n_var
+ grp_index = int(relindex / n_relrow)
+ n_relindex = int(relindex - grp_index * n_relrow)
+ n_val = val * float(p[grp_index, v:v+1])
- return n_relindex, n_relrow, n_val
+ return n_relindex, n_relrow, n_val
- for row in range(n_row):
+ for row in range(n_row):
- pt = slice(csr_data.indptr[row], csr_data.indptr[row+1])
+ pt = slice(csr_data.indptr[row], csr_data.indptr[row+1])
- def gen_val_per_row():
- for inner_idx, array_val in zip(csr_data.indices[pt], csr_data.data[pt]):
+ def gen_row_multiplication():
+ for col_idx, col_val in zip(csr_data.indices[pt], csr_data.data[pt]):
+ for perm in permutations():
- *_, last = itertools.accumulate(
+ *_, (_, _, val) = itertools.accumulate(
perm,
- acc_col_idx_and_value,
- initial=(inner_idx, n_index, array_val),
+ acc_kron_operation,
+ initial=(col_idx, n_col, col_val),
)
- _, _, val = last
yield val
-
- yield sum(gen_val_per_row())
+
+ yield sum(gen_row_multiplication())
+
+ # for perm in permutations:
- for perm in permutations:
+ row_values = tuple(gen_row_values())
- array_values = list(gen_coord(perm))
+ assert len(row_values) == n_row
- yield scipy.sparse.csr_array((array_values, np.zeros(len(array_values)), csr_data.indptr), shape=(csr_data.shape[0], n_col))
+ yield scipy.sparse.coo_array(
+ (row_values, (range(n_row), np.zeros(n_row))),
+ shape=(n_row, 1),
+ )
return np.concatenate((p, -b_num_inv @ sum(gen_p())), axis=1)
- *_, sol_subs = itertools.accumulate(range(m-1), func, initial=p)
+ *_, sol_subs = itertools.accumulate(range(m-1), func, initial=p0)
- # return np.asarray(sol_subs)
return sum(np.asarray(sol_subs).T)
@@ -297,7 +295,7 @@ def inner_smart_solve(data, irange=None, idegree=None, a_init=None):
return a, err
- a, err = list(zip(*itertools.accumulate(
+ a, err = tuple(zip(*itertools.accumulate(
irange,
acc_func,
initial=(a_init, 0),
@@ -319,27 +317,23 @@ def outer_smart_solve(data, a_init=None, n_iter=10, a_max=1.0, irange=None, ideg
try:
a, err = inner_smart_solve(data, irange=irange, idegree=idegree, a_init=a_subs)
- subs_data = substitude_x_add_a(data, a[-1])
- sol = solve_poly_system(subs_data, 6)
+ # subs_data = substitude_x_add_a(data, a[-1])
+ # sol = solve_poly_system(subs_data, 6)
- error_index = np.max(np.abs(eval_solution(data, a[-1] + sol))) + max(a[-1])
+ # error_index = np.max(np.abs(eval_solution(data, a[-1] + sol))) + max(a[-1])
+ # error_index = np.max(np.abs(eval_solution(data, a[-1] + sol)))
# error_index = np.max(np.abs(eval_solution(subs_data, sol))) + max(a[-1])
# error_index = np.max(np.abs(eval_solution(subs_data, sol)))
except:
print(f'nan error, continue')
- # print(f'nan error for {a_init=}, continue')
- # yield np.nan, a_init, np.nan
continue
- print(f'{error_index=}')
+ # print(f'{error_index=}')
- yield error_index, a, err
-
- # if error_index < 1.0:
- # break
+ yield a, err
- _, a, err = min(gen_a_err())
+ a, err = tuple(zip(*gen_a_err()))
return a, err
@@ -359,6 +353,7 @@ def eval_solution(data, x=None):
# yield d_data @ last
if not scipy.sparse.issparse(d_data):
+
*_, last = itertools.accumulate(degree*(x,), lambda acc, v: np.kron(acc, v))
yield d_data @ last
diff --git a/polymatrix/statemonad/__init__.py b/polymatrix/statemonad/__init__.py
new file mode 100644
index 0000000..33aa86a
--- /dev/null
+++ b/polymatrix/statemonad/__init__.py
@@ -0,0 +1,16 @@
+from polymatrix.statemonad.init.initstatemonad import init_state_monad
+from polymatrix.statemonad.statemonad import StateMonad
+
+
+def zip(monads: tuple[StateMonad]):
+
+ def zip_func(state):
+ values = tuple()
+
+ for monad in monads:
+ state, val = monad.apply(state)
+ values += (val,)
+
+ return state, values
+
+ return init_state_monad(zip_func)
diff --git a/polymatrix/statemonad/impl/__init__.py b/polymatrix/statemonad/impl/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/polymatrix/statemonad/impl/__init__.py
diff --git a/polymatrix/statemonad/impl/statemonadimpl.py b/polymatrix/statemonad/impl/statemonadimpl.py
new file mode 100644
index 0000000..5f8c6b8
--- /dev/null
+++ b/polymatrix/statemonad/impl/statemonadimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.statemonad.statemonad import StateMonad
+
+from typing import Callable
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class StateMonadImpl(StateMonad):
+ apply_func: Callable
diff --git a/polymatrix/statemonad/init/__init__.py b/polymatrix/statemonad/init/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/polymatrix/statemonad/init/__init__.py
diff --git a/polymatrix/statemonad/init/initstatemonad.py b/polymatrix/statemonad/init/initstatemonad.py
new file mode 100644
index 0000000..7d269f3
--- /dev/null
+++ b/polymatrix/statemonad/init/initstatemonad.py
@@ -0,0 +1,10 @@
+from typing import Callable
+from polymatrix.statemonad.impl.statemonadimpl import StateMonadImpl
+
+
+def init_state_monad(
+ apply_func: Callable,
+):
+ return StateMonadImpl(
+ apply_func=apply_func,
+)
diff --git a/polymatrix/statemonad/mixins/__init__.py b/polymatrix/statemonad/mixins/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/polymatrix/statemonad/mixins/__init__.py
diff --git a/polymatrix/statemonad/mixins/statemonadmixin.py b/polymatrix/statemonad/mixins/statemonadmixin.py
new file mode 100644
index 0000000..39b6576
--- /dev/null
+++ b/polymatrix/statemonad/mixins/statemonadmixin.py
@@ -0,0 +1,53 @@
+import abc
+import dataclasses
+from typing import Callable, Tuple, TypeVar, Generic
+import typing
+
+State = TypeVar('State')
+U = TypeVar('U')
+V = TypeVar('V')
+
+
+class StateMonadMixin(
+ Generic[State, U],
+ abc.ABC,
+):
+ @property
+ @abc.abstractmethod
+ def apply_func(self) -> typing.Callable[[State], tuple[State, U]]:
+ ...
+
+ # def init(func: Callable[[State], Tuple[State, U]]) -> 'StateMonadMixin[U, State]':
+ # class StateMonadImpl(StateMonadMixin):
+ # def apply(self, state: State) -> Tuple[State, U]:
+ # return func(state)
+
+ # return StateMonadImpl()
+
+ def map(self, fn: Callable[[U], V]) -> 'StateMonadMixin[State, V]':
+
+ def internal_map(state: State) -> Tuple[State, U]:
+ n_state, val = self.apply(state)
+ return n_state, fn(val)
+
+ return dataclasses.replace(self, apply_func=internal_map)
+
+ def flat_map(self, fn: Callable[[U], 'StateMonadMixin']) -> 'StateMonadMixin[State, V]':
+
+ def internal_map(state: State) -> Tuple[State, V]:
+ n_state, val = self.apply(state)
+ return fn(val).apply(n_state)
+
+ return dataclasses.replace(self, apply_func=internal_map)
+
+ def zip(self, other: 'StateMonadMixin') -> 'StateMonadMixin':
+ def internal_map(state: State) -> Tuple[State, V]:
+ state, val1 = self.apply(state)
+ state, val2 = other.apply(state)
+ return state, (val1, val2)
+
+ return dataclasses.replace(self, apply_func=internal_map)
+
+ # @abc.abstractmethod
+ def apply(self, state: State) -> Tuple[State, U]:
+ return self.apply_func(state)
diff --git a/polymatrix/statemonad/statemonad.py b/polymatrix/statemonad/statemonad.py
new file mode 100644
index 0000000..49ab1fa
--- /dev/null
+++ b/polymatrix/statemonad/statemonad.py
@@ -0,0 +1,4 @@
+from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin
+
+class StateMonad(StateMonadMixin):
+ pass