summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/__init__.py480
-rw-r--r--polymatrix/expression/blockdiagexpr.py4
-rw-r--r--polymatrix/expression/cacheexpr.py4
-rw-r--r--polymatrix/expression/impl/blockdiagexprimpl.py7
-rw-r--r--polymatrix/expression/impl/cacheexprimpl.py8
-rw-r--r--polymatrix/expression/impl/expressionstateimpl.py3
-rw-r--r--polymatrix/expression/impl/reshapeexprimpl.py9
-rw-r--r--polymatrix/expression/init/initblockdiagexpr.py9
-rw-r--r--polymatrix/expression/init/initcacheexpr.py10
-rw-r--r--polymatrix/expression/init/initexpressionstate.py2
-rw-r--r--polymatrix/expression/init/initfromtermsexpr.py13
-rw-r--r--polymatrix/expression/init/initreshapeexpr.py12
-rw-r--r--polymatrix/expression/mixins/addauxequationsexprmixin.py58
-rw-r--r--polymatrix/expression/mixins/blockdiagexprmixin.py67
-rw-r--r--polymatrix/expression/mixins/cacheexprmixin.py40
-rw-r--r--polymatrix/expression/mixins/determinantexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/divisionexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/evalexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/expressionbasemixin.py3
-rw-r--r--polymatrix/expression/mixins/expressionmixin.py35
-rw-r--r--polymatrix/expression/mixins/expressionstatemixin.py14
-rw-r--r--polymatrix/expression/mixins/linearinexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/parametrizetermsexprmixin.py55
-rw-r--r--polymatrix/expression/mixins/polymatrixmixin.py95
-rw-r--r--polymatrix/expression/mixins/quadraticinexprmixin.py12
-rw-r--r--polymatrix/expression/mixins/reshapeexprmixin.py75
-rw-r--r--polymatrix/expression/reshapeexpr.py4
-rw-r--r--polymatrix/statemonad/__init__.py9
-rw-r--r--polymatrix/statemonad/mixins/statemixin.py8
-rw-r--r--polymatrix/statemonad/mixins/statemonadmixin.py28
30 files changed, 788 insertions, 301 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 139e391..20b05a9 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -1,4 +1,7 @@
+import collections
+import dataclasses
import itertools
+import typing
import numpy as np
import scipy.sparse
# import polymatrix.statemonad
@@ -6,6 +9,7 @@ import scipy.sparse
from polymatrix.expression.expression import Expression
from polymatrix.expression.expressionstate import ExpressionState
from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr
+from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr
from polymatrix.expression.init.initexpression import init_expression
from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr
from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
@@ -27,6 +31,14 @@ def from_(
)
+def from_polymatrix(
+ polymatrix: PolyMatrix,
+):
+ return init_expression(
+ init_from_terms_expr(polymatrix)
+ )
+
+
def accumulate(
expr,
func,
@@ -34,7 +46,6 @@ def accumulate(
):
def lifted_func(acc, polymat: PolyMatrix):
-
# # print(f'{terms=}')
# print(f'{terms=}')
@@ -62,103 +73,228 @@ def v_stack(
)
-def kkt(
- cost: Expression,
+def h_stack(
+ expressions: tuple[Expression],
+):
+ return init_expression(
+ init_v_stack_expr(tuple(expr.T for expr in expressions))
+ ).T
+
+
+def block_diag(
+ expressions: tuple[Expression],
+):
+ return init_expression(
+ init_block_diag_expr(expressions)
+ )
+
+
+# def kkt(
+# cost: Expression,
+# variables: Expression,
+# equality: Expression = None,
+# inequality: Expression = None,
+# ):
+# return init_expression(
+# init_kkt_expr(
+# cost=cost,
+# equality=equality,
+# variables=variables,
+# inequality=inequality,
+# )
+# )
+
+def kkt_equality(
variables: Expression,
equality: Expression = None,
+):
+
+ self_variables = variables
+ self_equality = equality
+
+ def func(state: ExpressionState):
+
+ state, equality = self_equality.apply(state=state)
+
+ state, equality_der = self_equality.diff(
+ self_variables,
+ introduce_derivatives=True,
+ ).apply(state)
+
+ def acc_nu_variables(acc, v):
+ state, nu_variables = acc
+
+ nu_variable = state.n_param
+ state = state.register(n_param=1)
+
+ return state, nu_variables + [nu_variable]
+
+ *_, (state, nu_variables) = tuple(itertools.accumulate(
+ range(equality.shape[0]),
+ acc_nu_variables,
+ initial=(state, []),
+ ))
+
+ terms = {}
+
+ for row in range(equality_der.shape[1]):
+
+ monomial_terms = collections.defaultdict(float)
+
+ for eq_idx, nu_variable in enumerate(nu_variables):
+
+ try:
+ underlying_terms = equality_der.get_poly(eq_idx, row)
+ except KeyError:
+ continue
+
+ for monomial, value in underlying_terms.items():
+ new_monomial = monomial + (nu_variable,)
+
+ monomial_terms[new_monomial] += value
+
+ terms[row, 0] = dict(monomial_terms)
+
+ cost_expr = init_expression(init_from_terms_expr(
+ terms=terms,
+ shape=(equality_der.shape[1], 1),
+ ))
+
+ nu_terms = {}
+
+ for eq_idx, nu_variable in enumerate(nu_variables):
+ nu_terms[eq_idx, 0] = {(nu_variable,): 1}
+
+ nu_expr = init_expression(init_from_terms_expr(
+ terms=nu_terms,
+ shape=(len(nu_variables), 1),
+ ))
+
+ return state, (nu_expr, cost_expr)
+
+ return init_state_monad(func)
+
+def kkt_inequality(
+ variables: Expression,
inequality: Expression = None,
):
- return init_expression(
- init_kkt_expr(
- cost=cost,
- equality=equality,
- variables=variables,
- inequality=inequality,
- )
- )
- # self_cost = cost
- # self_variables = variables
- # self_equality = equality
+ self_variables = variables
+ self_inequality = inequality
- # def func(state: ExpressionState):
- # state, cost = self_cost.apply(state=state)
+ def func(state: ExpressionState):
- # assert cost.shape[1] == 1
+ state, inequality = self_inequality.apply(state=state)
- # if self_equality is not None:
+ state, inequality_der = self_inequality.diff(
+ self_variables,
+ introduce_derivatives=True,
+ ).apply(state)
- # state, equality = self_equality.apply(state=state)
+ def acc_lambda_variables(acc, v):
+ state, lambda_variables = acc
- # state, equality_der = self_equality.diff(
- # self_variables,
- # introduce_derivatives=True,
- # ).apply(state)
+ lambda_variable = state.n_param
+ state = state.register(n_param=3)
+
+ return state, lambda_variables + [lambda_variable]
- # assert cost.shape[0] == equality_der.shape[1]
+ *_, (state, lambda_variables) = tuple(itertools.accumulate(
+ range(inequality.shape[0]),
+ acc_lambda_variables,
+ initial=(state, []),
+ ))
- # def acc_nu_variables(acc, v):
- # state, nu_variables = acc
+ terms = {}
- # nu_variable = state.n_param
- # state = state.register(n_param=1)
-
- # return state, nu_variables + [nu_variable]
+ for row in range(inequality_der.shape[1]):
- # *_, (state, nu_variables) = tuple(itertools.accumulate(
- # range(equality.shape[0]),
- # acc_nu_variables,
- # initial=(state, []),
- # ))
+ monomial_terms = collections.defaultdict(float)
- # else:
- # nu_variables = tuple()
+ for inequality_idx, lambda_variable in enumerate(lambda_variables):
- # idx_start = 0
+ try:
+ underlying_terms = inequality_der.get_poly(inequality_idx, row)
+ except KeyError:
+ continue
- # terms = {}
+ for monomial, value in underlying_terms.items():
+ new_monomial = monomial + (lambda_variable,)
+
+ monomial_terms[new_monomial] += value
+
+ terms[row, 0] = dict(monomial_terms)
+
+ cost_expr = init_expression(init_from_terms_expr(
+ terms=terms,
+ shape=(inequality_der.shape[1], 1),
+ ))
+
+ # inequality, dual feasibility, complementary slackness
+ # -----------------------------------------------------
+
+ inequality_terms = {}
+ feasibility_terms = {}
+ complementary_terms = {}
+
+ for inequality_idx, lambda_variable in enumerate(lambda_variables):
+ r_lambda = lambda_variable + 1
+ r_inequality = lambda_variable + 2
- # for row in range(cost.shape[0]):
- # try:
- # monomial_terms = cost.get_poly(row, 0)
- # except KeyError:
- # monomial_terms = {}
+ try:
+ underlying_terms = inequality.get_poly(inequality_idx, 0)
+ except KeyError:
+ continue
+
+ # f(x) <= -0.01
+ inequality_terms[inequality_idx, 0] = underlying_terms | {(r_inequality, r_inequality): 1}
- # for eq_idx, nu_variable in enumerate(nu_variables):
+ # dual feasibility, lambda >= 0
+ feasibility_terms[inequality_idx, 0] = {(lambda_variable,): 1, (r_lambda, r_lambda): -1}
- # try:
- # underlying_terms = equality_der.get_poly(eq_idx, row)
- # except KeyError:
- # continue
+ # complementary slackness
+ complementary_terms[inequality_idx, 0] = {(r_lambda, r_inequality): 1}
- # for monomial, value in underlying_terms.items():
- # new_monomial = monomial + (nu_variable,)
+ inequality_expr = init_expression(init_from_terms_expr(
+ terms=inequality_terms,
+ shape=(len(lambda_variables), 1),
+ ))
- # if new_monomial not in monomial_terms:
- # monomial_terms[new_monomial] = 0
+ feasibility_expr = init_expression(init_from_terms_expr(
+ terms=feasibility_terms,
+ shape=(len(lambda_variables), 1),
+ ))
- # monomial_terms[new_monomial] += value
+ complementary_expr = init_expression(init_from_terms_expr(
+ terms=complementary_terms,
+ shape=(len(lambda_variables), 1),
+ ))
- # terms[idx_start, 0] = monomial_terms
- # idx_start += 1
+ # lambda expression
+ # -----------------
- # cost_expr = init_expression(init_from_terms_expr(
- # terms=terms,
- # shape=(idx_start, 1),
- # ))
+ terms = {}
+ for inequality_idx, lambda_variable in enumerate(lambda_variables):
+ terms[inequality_idx, 0] = {(lambda_variable,): 1}
- # terms = {}
- # for eq_idx, nu_variable in enumerate(nu_variables):
- # terms[eq_idx, 0] = {(nu_variable,): 1}
+ lambda_expr = init_expression(init_from_terms_expr(
+ terms=terms,
+ shape=(len(lambda_variables), 1),
+ ))
+
+ return state, (lambda_expr, cost_expr, inequality_expr, feasibility_expr, complementary_expr)
+
+ return init_state_monad(func)
- # nu_expr = init_expression(init_from_terms_expr(
- # terms=terms,
- # shape=(len(nu_variables), 1),
- # ))
+# def to_polymatrix(
+# expr: Expression,
+# ):
+# def func(state: ExpressionState):
+# state, polymatrix = expr.apply(state)
- # return state, (cost_expr, nu_expr)
+# return state, polymatrix
- # return StateMonadMixin.init(func)
+# return init_state_monad(func)
# def to_linear_matrix(
# expr: Expression,
@@ -181,13 +317,62 @@ def kkt(
# return StateMonad.init(func)
+@dataclasses.dataclass
+class MatrixEquations:
+ matrix_equations: tuple[tuple[np.ndarray, ...], ...]
+ auxillary_matrix_equations: typing.Optional[tuple[np.ndarray, ...]]
+ variable_index: tuple[int, ...]
+ state: ExpressionState
+
+ def merge_matrix_equations(self):
+ def gen_matrices(index: int):
+ for equations in self.matrix_equations:
+ if index < len(equations):
+ yield equations[index]
+
+ if index < len(self.auxillary_matrix_equations):
+ yield self.auxillary_matrix_equations[index]
+
+ matrix_1 = np.vstack(tuple(gen_matrices(0)))
+ matrix_2 = np.vstack(tuple(gen_matrices(1)))
+ matrix_3 = scipy.sparse.vstack(tuple(gen_matrices(2)))
+
+ return (matrix_1, matrix_2, matrix_3)
+
+ def get_value(self, variable, value):
+ if isinstance(variable, Expression):
+ variable = variable.underlying
+
+ offset = self.state.offset_dict[variable]
+ offset_idx = list(self.variable_index.index(idx) for idx in range(*offset))
+ return value[offset_idx]
+
+
def to_matrix_equations(
- expr: Expression,
-) -> StateMonadMixin[ExpressionState, tuple[tuple[np.ndarray, ...], tuple[int, ...]]]:
+ expr: tuple[Expression],
+) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]:
+
+ if isinstance(expr, Expression):
+ expr = (expr,)
+
def func(state: ExpressionState):
- state, underlying = expr.apply(state)
- assert underlying.shape[1] == 1
+ def acc_underlying(acc, v):
+ state, underlying_list = acc
+
+ state, underlying = v.apply(state)
+
+ assert underlying.shape[1] == 1, f'{underlying.shape[1]} is not 1'
+
+ return state, underlying_list + (underlying,)
+
+ *_, (state, underlying_list) = tuple(itertools.accumulate(
+ expr,
+ acc_underlying,
+ initial=(state, tuple()),
+ ))
+
+ # state, underlying = expr.apply(state)
def gen_used_variables():
def gen_used_auxillary_variables(considered):
@@ -199,73 +384,126 @@ def to_matrix_equations(
if variable not in considered and variable in state.auxillary_equations:
yield from gen_used_auxillary_variables(considered + (variable,))
- for row in range(underlying.shape[0]):
- for col in range(underlying.shape[1]):
+ for underlying in underlying_list:
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
- continue
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
- for monomial in underlying_terms.keys():
- for variable in monomial:
- yield variable
+ for monomial in underlying_terms.keys():
+ for variable in monomial:
+ yield variable
- if variable in state.auxillary_equations:
- yield from gen_used_auxillary_variables((variable,))
+ if variable in state.auxillary_equations:
+ yield from gen_used_auxillary_variables((variable,))
used_variables = set(gen_used_variables())
-
ordered_variable_index = tuple(sorted(used_variables))
variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
n_param = len(ordered_variable_index)
- A = np.zeros((n_param, 1), dtype=np.float32)
- B = np.zeros((n_param, n_param), dtype=np.float32)
- C = scipy.sparse.dok_array((n_param, n_param**2), dtype=np.float32)
+ def gen_underlying_matrices():
+ for underlying in underlying_list:
+ n_row = underlying.shape[0]
+
+ A = np.zeros((n_row, 1), dtype=np.double)
+ B = np.zeros((n_row, n_param), dtype=np.double)
+ C = scipy.sparse.dok_array((n_row, n_param**2), dtype=np.double)
+
+ # def populate_matrices(monomial_terms, row):
+ # for monomial, value in monomial_terms.items():
+ # new_monomial = tuple(variable_index_map[var] for var in monomial)
+ # col = monomial_to_index(n_param, new_monomial)
+
+ # match len(new_monomial):
+ # case 0:
+ # A[row, col] = value
+ # case 1:
+ # B[row, col] = value
+ # case 2:
+ # C[row, col] = value
+ # case _:
+ # raise Exception(f'illegal case {new_monomial=}')
+
+ for row in range(underlying.shape[0]):
+ try:
+ underlying_terms = underlying.get_poly(row, 0)
+ except KeyError:
+ continue
- def populate_matrices(monomial_terms, row):
- for monomial, value in monomial_terms.items():
- new_monomial = tuple(variable_index_map[var] for var in monomial)
- col = monomial_to_index(n_param, new_monomial)
+ # populate_matrices(
+ # monomial_terms=underlying_terms,
+ # row=row,
+ # )
+
+ for monomial, value in underlying_terms.items():
+ new_monomial = tuple(variable_index_map[var] for var in monomial)
+ col = monomial_to_index(n_param, new_monomial)
+
+ match len(new_monomial):
+ case 0:
+ A[row, col] = value
+ case 1:
+ B[row, col] = value
+ case 2:
+ C[row, col] = value
+ case _:
+ raise Exception(f'illegal case {new_monomial=}')
+
+ yield A, B, C
- match len(new_monomial):
- case 0:
- A[row, col] = value
- case 1:
- B[row, col] = value
- case 2:
- C[row, col] = value
- case _:
- raise Exception(f'illegal case {new_monomial=}')
+ underlying_matrices = tuple(gen_underlying_matrices())
- for row in range(underlying.shape[0]):
- try:
- underlying_terms = underlying.get_poly(row, 0)
- except KeyError:
- continue
+ # current_row = underlying.shape[0]
+
+ def gen_auxillary_equations():
+ for key, monomial_terms in state.auxillary_equations.items():
+ if key in ordered_variable_index:
+ yield key, monomial_terms
- populate_matrices(
- monomial_terms=underlying_terms,
- row=row,
- )
+ auxillary_equations = tuple(gen_auxillary_equations())
- current_row = underlying.shape[0]
+ n_row = len(auxillary_equations)
- for key, monomial_terms in state.auxillary_equations.items():
- if key in ordered_variable_index:
- populate_matrices(
- monomial_terms=monomial_terms,
- row=current_row,
- )
- current_row += 1
+ if n_row == 0:
+ auxillary_matrix_equations = None
- # assert current_row == n_param, f'{current_row} is not {n_param}'
+ else:
+ A = np.zeros((n_row, 1), dtype=np.double)
+ B = np.zeros((n_row, n_param), dtype=np.double)
+ C = scipy.sparse.dok_array((n_row, n_param**2), dtype=np.double)
- return state, ((A, B, C), ordered_variable_index)
+ for row, (key, monomial_terms) in enumerate(auxillary_equations):
+ for monomial, value in monomial_terms.items():
+ new_monomial = tuple(variable_index_map[var] for var in monomial)
+ col = monomial_to_index(n_param, new_monomial)
- return StateMonadMixin.init(func)
+ match len(new_monomial):
+ case 0:
+ A[row, col] = value
+ case 1:
+ B[row, col] = value
+ case 2:
+ C[row, col] = value
+ case _:
+ raise Exception(f'illegal case {new_monomial=}')
+
+ auxillary_matrix_equations = (A, B, C)
+
+ result = MatrixEquations(
+ matrix_equations=underlying_matrices,
+ auxillary_matrix_equations=auxillary_matrix_equations,
+ variable_index=ordered_variable_index,
+ state=state,
+ )
+
+ return state, result
+
+ return init_state_monad(func)
def to_constant_matrix(
expr: Expression,
diff --git a/polymatrix/expression/blockdiagexpr.py b/polymatrix/expression/blockdiagexpr.py
new file mode 100644
index 0000000..e3acee5
--- /dev/null
+++ b/polymatrix/expression/blockdiagexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin
+
+class BlockDiagExpr(BlockDiagExprMixin):
+ pass
diff --git a/polymatrix/expression/cacheexpr.py b/polymatrix/expression/cacheexpr.py
new file mode 100644
index 0000000..5ae4052
--- /dev/null
+++ b/polymatrix/expression/cacheexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.cacheexprmixin import CacheExprMixin
+
+class CacheExpr(CacheExprMixin):
+ pass
diff --git a/polymatrix/expression/impl/blockdiagexprimpl.py b/polymatrix/expression/impl/blockdiagexprimpl.py
new file mode 100644
index 0000000..a2707d8
--- /dev/null
+++ b/polymatrix/expression/impl/blockdiagexprimpl.py
@@ -0,0 +1,7 @@
+import dataclass_abc
+from polymatrix.expression.blockdiagexpr import BlockDiagExpr
+
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class BlockDiagExprImpl(BlockDiagExpr):
+ underlying: tuple
diff --git a/polymatrix/expression/impl/cacheexprimpl.py b/polymatrix/expression/impl/cacheexprimpl.py
new file mode 100644
index 0000000..a6a74a1
--- /dev/null
+++ b/polymatrix/expression/impl/cacheexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.cacheexpr import CacheExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class CacheExprImpl(CacheExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/expressionstateimpl.py b/polymatrix/expression/impl/expressionstateimpl.py
index 7513970..5459eb7 100644
--- a/polymatrix/expression/impl/expressionstateimpl.py
+++ b/polymatrix/expression/impl/expressionstateimpl.py
@@ -1,3 +1,4 @@
+from functools import cached_property
import dataclass_abc
from polymatrix.expression.expressionstate import ExpressionState
@@ -8,4 +9,4 @@ class ExpressionStateImpl(ExpressionState):
n_param: int
offset_dict: dict
auxillary_equations: dict[int, dict[tuple[int], float]]
- cached_polymatrix: dict
+ cache: dict
diff --git a/polymatrix/expression/impl/reshapeexprimpl.py b/polymatrix/expression/impl/reshapeexprimpl.py
new file mode 100644
index 0000000..7c16f16
--- /dev/null
+++ b/polymatrix/expression/impl/reshapeexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.reshapeexpr import ReshapeExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class ReshapeExprImpl(ReshapeExpr):
+ underlying: ExpressionBaseMixin
+ new_shape: tuple
diff --git a/polymatrix/expression/init/initblockdiagexpr.py b/polymatrix/expression/init/initblockdiagexpr.py
new file mode 100644
index 0000000..930385f
--- /dev/null
+++ b/polymatrix/expression/init/initblockdiagexpr.py
@@ -0,0 +1,9 @@
+from polymatrix.expression.impl.blockdiagexprimpl import BlockDiagExprImpl
+
+
+def init_block_diag_expr(
+ underlying: tuple,
+):
+ return BlockDiagExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/init/initcacheexpr.py b/polymatrix/expression/init/initcacheexpr.py
new file mode 100644
index 0000000..977e033
--- /dev/null
+++ b/polymatrix/expression/init/initcacheexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.cacheexprimpl import CacheExprImpl
+
+
+def init_cache_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return CacheExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/init/initexpressionstate.py b/polymatrix/expression/init/initexpressionstate.py
index 7e8a6fe..be2e4f1 100644
--- a/polymatrix/expression/init/initexpressionstate.py
+++ b/polymatrix/expression/init/initexpressionstate.py
@@ -15,5 +15,5 @@ def init_expression_state(
n_param=n_param,
offset_dict=offset_dict,
auxillary_equations={},
- cached_polymatrix={},
+ cache={},
)
diff --git a/polymatrix/expression/init/initfromtermsexpr.py b/polymatrix/expression/init/initfromtermsexpr.py
index 80d5198..1abac12 100644
--- a/polymatrix/expression/init/initfromtermsexpr.py
+++ b/polymatrix/expression/init/initfromtermsexpr.py
@@ -1,10 +1,19 @@
+import typing
from polymatrix.expression.impl.fromtermsexprimpl import FromTermsExprImpl
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
def init_from_terms_expr(
- terms: tuple,
- shape: tuple[int, int]
+ terms: typing.Union[tuple, PolyMatrixMixin],
+ shape: tuple[int, int] = None,
):
+ if isinstance(terms, PolyMatrixMixin):
+ shape = terms.shape
+ terms = terms.get_terms()
+
+ else:
+ assert shape is not None
+
if isinstance(terms, dict):
terms = tuple((key, tuple(value.items())) for key, value in terms.items())
diff --git a/polymatrix/expression/init/initreshapeexpr.py b/polymatrix/expression/init/initreshapeexpr.py
new file mode 100644
index 0000000..f95fb00
--- /dev/null
+++ b/polymatrix/expression/init/initreshapeexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.reshapeexprimpl import ReshapeExprImpl
+
+
+def init_reshape_expr(
+ underlying: ExpressionBaseMixin,
+ new_shape: tuple,
+):
+ return ReshapeExprImpl(
+ underlying=underlying,
+ new_shape=new_shape,
+)
diff --git a/polymatrix/expression/mixins/addauxequationsexprmixin.py b/polymatrix/expression/mixins/addauxequationsexprmixin.py
new file mode 100644
index 0000000..1dae4f7
--- /dev/null
+++ b/polymatrix/expression/mixins/addauxequationsexprmixin.py
@@ -0,0 +1,58 @@
+
+import abc
+import dataclasses
+import itertools
+import dataclass_abc
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.expressionstate import ExpressionState
+
+
+# is this really needed?
+class AddAuxEquationsExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def _apply(
+ self,
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
+
+ state, underlying = self.underlying.apply(state=state)
+
+ assert underlying.shape[1] == 1
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class AddAuxEquationsPolyMatrix(PolyMatrixMixin):
+ underlying: tuple[PolyMatrixMixin]
+ shape: tuple[int, int]
+ n_row: int
+ auxillary_equations: tuple[dict[tuple[int], float]]
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ if row < self.n_row:
+ return self.underlying.get_poly(row, col)
+
+ elif row < self.shape[0]:
+ return self.auxillary_equations[row - self.n_row]
+
+ else:
+ raise Exception(f'row {row} is out of bounds')
+
+ auxillary_equations = tuple(state.auxillary_equations.values())
+
+ polymat = AddAuxEquationsPolyMatrix(
+ underlying=underlying,
+ shape=(underlying.shape[0] + len(auxillary_equations), 1),
+ n_row=underlying.shape[0],
+ auxillary_equations=auxillary_equations,
+ )
+
+ state = dataclasses.replace(state, auxillary_equations={})
+
+ return state, polymat
diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py
new file mode 100644
index 0000000..6bdbce1
--- /dev/null
+++ b/polymatrix/expression/mixins/blockdiagexprmixin.py
@@ -0,0 +1,67 @@
+
+import abc
+import itertools
+import dataclass_abc
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.expressionstate import ExpressionState
+
+
+class BlockDiagExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> tuple[ExpressionBaseMixin, ...]:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def _apply(
+ self,
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
+
+ all_underlying = []
+ for expr in self.underlying:
+ state, polymat = expr.apply(state=state)
+ all_underlying.append(polymat)
+
+ # assert all(underlying.shape[0] == underlying.shape[1] for underlying in all_underlying)
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class BlockDiagPolyMatrix(PolyMatrixMixin):
+ all_underlying: tuple[PolyMatrixMixin]
+ underlying_row_col_range: tuple[tuple[int, int], ...]
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ for polymat, ((row_start, col_start), (row_end, col_end)) in zip(self.all_underlying, self.underlying_row_col_range):
+
+ if row_start <= row < row_end:
+ if col_start <= col < col_end:
+ return polymat.get_poly(
+ row=row-row_start,
+ col=col-col_start,
+ )
+
+ else:
+ raise KeyError()
+
+ raise Exception(f'row {row} is out of bounds')
+
+ underlying_row_col_range = tuple(itertools.pairwise(
+ itertools.accumulate(
+ (expr.shape for expr in all_underlying),
+ lambda acc, v: tuple(v1+v2 for v1, v2 in zip(acc, v)),
+ initial=(0, 0))
+ ))
+
+ shape = underlying_row_col_range[-1][1]
+
+ polymat = BlockDiagPolyMatrix(
+ all_underlying=all_underlying,
+ shape=shape,
+ underlying_row_col_range=underlying_row_col_range,
+ )
+
+ return state, polymat
diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py
new file mode 100644
index 0000000..8779ba1
--- /dev/null
+++ b/polymatrix/expression/mixins/cacheexprmixin.py
@@ -0,0 +1,40 @@
+
+import abc
+import dataclasses
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+
+class CacheExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractclassmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def _apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ if self in state.cache:
+ return state, state.cache[self]
+
+ state, underlying = self.underlying.apply(state)
+
+ cached_terms = dict(underlying.get_terms())
+
+ poly_matrix = init_poly_matrix(
+ terms=cached_terms,
+ shape=underlying.shape,
+ )
+
+ state = dataclasses.replace(
+ state,
+ cache=state.cache | {self: poly_matrix},
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py
index 5b69d3f..5436ed6 100644
--- a/polymatrix/expression/mixins/determinantexprmixin.py
+++ b/polymatrix/expression/mixins/determinantexprmixin.py
@@ -26,8 +26,8 @@ class DeterminantExprMixin(ExpressionBaseMixin):
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
- if self in state.cached_polymatrix:
- return state, state.cached_polymatrix[self]
+ if self in state.cache:
+ return state, state.cache[self]
state, underlying = self.underlying.apply(state=state)
@@ -110,7 +110,7 @@ class DeterminantExprMixin(ExpressionBaseMixin):
state = dataclasses.replace(
state,
auxillary_equations=state.auxillary_equations | auxillary_equations,
- cached_polymatrix=state.cached_polymatrix | {self: poly_matrix},
+ cache=state.cache | {self: poly_matrix},
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py
index 4b25ec6..cfb6bea 100644
--- a/polymatrix/expression/mixins/divisionexprmixin.py
+++ b/polymatrix/expression/mixins/divisionexprmixin.py
@@ -30,8 +30,8 @@ class DivisionExprMixin(ExpressionBaseMixin):
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
- if self in state.cached_polymatrix:
- return state, state.cached_polymatrix[self]
+ if self in state.cache:
+ return state, state.cache[self]
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
@@ -76,7 +76,7 @@ class DivisionExprMixin(ExpressionBaseMixin):
state = dataclasses.replace(
state,
auxillary_equations=state.auxillary_equations | {division_variable: auxillary_terms},
- cached_polymatrix=state.cached_polymatrix | {self: poly_matrix},
+ cached_polymatrix=state.cache | {self: poly_matrix},
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py
index 4d78b26..d4fdbea 100644
--- a/polymatrix/expression/mixins/evalexprmixin.py
+++ b/polymatrix/expression/mixins/evalexprmixin.py
@@ -73,13 +73,15 @@ class EvalExprMixin(ExpressionBaseMixin):
initial=(tuple(), value),
))
- # print(f'{new_monomial=}')
-
if new_monomial not in terms_row_col:
terms_row_col[new_monomial] = 0
terms_row_col[new_monomial] += new_value
+ # delete zero entries
+ if terms_row_col[new_monomial] == 0:
+ del terms_row_col[new_monomial]
+
terms[row, col] = terms_row_col
poly_matrix = init_poly_matrix(
diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py
index baf605d..3825ed1 100644
--- a/polymatrix/expression/mixins/expressionbasemixin.py
+++ b/polymatrix/expression/mixins/expressionbasemixin.py
@@ -1,11 +1,9 @@
import abc
from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
-from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin
class ExpressionBaseMixin(
- # StateMonad[ExpressionStateMixin, PolyMatrixMixin],
abc.ABC
):
@@ -13,7 +11,6 @@ class ExpressionBaseMixin(
def _apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
...
- @abc.abstractmethod
def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
assert isinstance(state, ExpressionStateMixin), f'{state} is not of type {ExpressionStateMixin.__name__}'
diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py
index aa9ca17..4cb4a50 100644
--- a/polymatrix/expression/mixins/expressionmixin.py
+++ b/polymatrix/expression/mixins/expressionmixin.py
@@ -5,6 +5,7 @@ import numpy as np
from sympy import re
from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr
from polymatrix.expression.init.initadditionexpr import init_addition_expr
+from polymatrix.expression.init.initcacheexpr import init_cache_expr
from polymatrix.expression.init.initderivativeexpr import init_derivative_expr
from polymatrix.expression.init.initdeterminantexpr import init_determinant_expr
from polymatrix.expression.init.initdivisionexpr import init_division_expr
@@ -18,6 +19,7 @@ from polymatrix.expression.init.initmatrixmultexpr import init_matrix_mult_expr
from polymatrix.expression.init.initparametrizetermsexpr import init_parametrize_terms_expr
from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr
from polymatrix.expression.init.initrepmatexpr import init_rep_mat_expr
+from polymatrix.expression.init.initreshapeexpr import init_reshape_expr
from polymatrix.expression.init.inittoquadraticexpr import init_to_quadratic_expr
from polymatrix.expression.init.inittransposeexpr import init_transpose_expr
@@ -37,7 +39,7 @@ class ExpressionMixin(
...
# overwrites abstract method of `PolyMatrixExprBaseMixin`
- def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
+ def _apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
return self.underlying.apply(state)
# # overwrites abstract method of `PolyMatrixExprBaseMixin`
@@ -45,11 +47,6 @@ class ExpressionMixin(
# def degrees(self) -> set[int]:
# return self.underlying.degrees
- # # overwrites abstract method of `PolyMatrixExprBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return self.underlying.shape
-
# def __iter__(self):
# for row in range(self.shape[0]):
# yield self[row, 0]
@@ -93,7 +90,11 @@ class ExpressionMixin(
def __mul__(self, other) -> 'ExpressionMixin':
# assert isinstance(other, float)
- right = init_from_array_expr(other)
+ match other:
+ case ExpressionBaseMixin():
+ right = other.underlying
+ case _:
+ right = init_from_array_expr(other)
return dataclasses.replace(
self,
@@ -106,6 +107,9 @@ class ExpressionMixin(
def __rmul__(self, other):
return self * other
+ def __neg__(self):
+ return self * (-1)
+
def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin':
match other:
case ExpressionBaseMixin():
@@ -153,6 +157,14 @@ class ExpressionMixin(
),
)
+ def cache(self) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_cache_expr(
+ underlying=self.underlying,
+ ),
+ )
+
def parametrize(self, name: str, variables: tuple) -> 'ExpressionMixin':
return dataclasses.replace(
self,
@@ -163,6 +175,15 @@ class ExpressionMixin(
),
)
+ def reshape(self, n: int, m: int) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_reshape_expr(
+ underlying=self.underlying,
+ new_shape=(n, m),
+ ),
+ )
+
def rep_mat(self, n: int, m: int) -> 'ExpressionMixin':
return dataclasses.replace(
self,
diff --git a/polymatrix/expression/mixins/expressionstatemixin.py b/polymatrix/expression/mixins/expressionstatemixin.py
index 526b422..523d2b7 100644
--- a/polymatrix/expression/mixins/expressionstatemixin.py
+++ b/polymatrix/expression/mixins/expressionstatemixin.py
@@ -5,8 +5,12 @@ import typing
import sympy
+from polymatrix.statemonad.mixins.statemixin import StateCacheMixin
-class ExpressionStateMixin(abc.ABC):
+
+class ExpressionStateMixin(
+ StateCacheMixin,
+):
@property
@abc.abstractmethod
@@ -27,10 +31,10 @@ class ExpressionStateMixin(abc.ABC):
def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]:
...
- @property
- @abc.abstractmethod
- def cached_polymatrix(self) -> dict:
- ...
+ # @property
+ # @abc.abstractmethod
+ # def cache(self) -> dict:
+ # ...
def register(
self,
diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py
index 557bed0..9ea4e92 100644
--- a/polymatrix/expression/mixins/linearinexprmixin.py
+++ b/polymatrix/expression/mixins/linearinexprmixin.py
@@ -7,6 +7,7 @@ from polymatrix.expression.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.expression.expressionstate import ExpressionState
+from polymatrix.expression.utils.getvariableindices import get_variable_indices
class LinearInExprMixin(ExpressionBaseMixin):
@@ -28,8 +29,8 @@ class LinearInExprMixin(ExpressionBaseMixin):
state, underlying = self.underlying.apply(state=state)
# todo: uncomment this
- # state, variable_indices = get_variable_indices(state, self.variables)
- variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
+ state, variable_indices = get_variable_indices(state, self.variables)
+ # variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
terms = {}
idx_row = 0
@@ -49,6 +50,8 @@ class LinearInExprMixin(ExpressionBaseMixin):
x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
+ assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted'
+
x_monomial_terms[x_monomial][p_monomial] += value
for data in x_monomial_terms.values():
diff --git a/polymatrix/expression/mixins/parametrizetermsexprmixin.py b/polymatrix/expression/mixins/parametrizetermsexprmixin.py
index f20306c..255a5da 100644
--- a/polymatrix/expression/mixins/parametrizetermsexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizetermsexprmixin.py
@@ -28,56 +28,14 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
def variables(self) -> tuple:
...
- # @property
- # def shape(self) -> tuple[int, int]:
- # return self.underlying.shape
-
- # @dataclass_abc.dataclass_abc
- # class ParametrizeTermsPolyMatrix(PolyMatrixAsDictMixin):
- # shape: tuple[int, int]
- # terms: dict
- # start_index: int
- # n_param: int
-
- # @property
- # def param(self) -> tuple[int, int]:
- # outer_self = self
-
- # @dataclass_abc.dataclass_abc(frozen=True)
- # class ParameterExpr(ExpressionBaseMixin):
- # def apply(
- # self,
- # state: ExpressionStateMixin,
- # ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
-
- # state, poly_matrix = outer_self.apply(state)
-
- # n_param = poly_matrix.n_param
- # start_index = poly_matrix.start_index
-
- # def gen_monomials():
- # for rel_index in range(n_param):
- # yield {(start_index + rel_index,): 1}
-
- # terms = {(row, 0): monomial_terms for row, monomial_terms in enumerate(gen_monomials())}
-
- # poly_matrix = init_poly_matrix(
- # terms=terms,
- # shape=(n_param, 1),
- # )
-
- # return state, poly_matrix
-
- # return ParameterExpr()
-
# overwrites abstract method of `ExpressionBaseMixin`
def _apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
- if self in state.cached_polymatrix:
- return state, state.cached_polymatrix[self]
+ if self in state.cache:
+ return state, state.cache[self]
# if not hasattr(self, '_terms'):
state, underlying = self.underlying.apply(state)
@@ -124,13 +82,6 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
n_param=start_index - state.n_param,
)
- # poly_matrix = ParametrizeTermsExprMixin.ParametrizeTermsPolyMatrix(
- # terms=terms,
- # shape=underlying.shape,
- # start_index=idx_start,
- # n_param=n_param,
- # )
-
poly_matrix = init_poly_matrix(
terms=terms,
shape=underlying.shape,
@@ -138,7 +89,7 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
state = dataclasses.replace(
state,
- cached_polymatrix=state.cached_polymatrix | {self: poly_matrix},
+ cache=state.cache | {self: poly_matrix},
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/polymatrixmixin.py b/polymatrix/expression/mixins/polymatrixmixin.py
index f83a62c..89e05cb 100644
--- a/polymatrix/expression/mixins/polymatrixmixin.py
+++ b/polymatrix/expression/mixins/polymatrixmixin.py
@@ -14,88 +14,19 @@ class PolyMatrixMixin(abc.ABC):
def shape(self) -> tuple[int, int]:
...
+ def get_terms(self) -> tuple[tuple[int, int], dict[tuple[int, ...], float]]:
+ def gen_terms():
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+ try:
+ monomial_terms = self.get_poly(row, col)
+ except KeyError:
+ continue
+
+ yield (row, col), monomial_terms
+
+ return tuple(gen_terms())
+
@abc.abstractclassmethod
def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
...
-
- # def get_equations(
- # self,
- # state: ExpressionStateMixin,
- # ):
- # assert self.shape[1] == 1
-
- # def gen_used_variables():
- # def gen_used_auxillary_variables(considered):
- # monomial_terms = state.auxillary_equations[considered[-1]]
- # for monomial in monomial_terms.keys():
- # for variable in monomial:
- # yield variable
-
- # if variable not in considered and variable in state.auxillary_equations:
- # yield from gen_used_auxillary_variables(considered + (variable,))
-
- # for row in range(self.shape[0]):
- # for col in range(self.shape[1]):
-
- # try:
- # underlying_terms = self.get_poly(row, col)
- # except KeyError:
- # continue
-
- # for monomial in underlying_terms.keys():
- # for variable in monomial:
- # yield variable
-
- # if variable in state.auxillary_equations:
- # yield from gen_used_auxillary_variables((variable,))
-
- # used_variables = set(gen_used_variables())
-
- # ordered_variable_index = tuple(sorted(used_variables))
- # variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
-
- # n_param = len(ordered_variable_index)
-
- # A = np.zeros((n_param, 1), dtype=np.float32)
- # B = np.zeros((n_param, n_param), dtype=np.float32)
- # C = scipy.sparse.dok_array((n_param, n_param**2), dtype=np.float32)
-
- # def populate_matrices(monomial_terms, row):
- # for monomial, value in monomial_terms.items():
- # new_monomial = tuple(variable_index_map[var] for var in monomial)
- # col = monomial_to_index(n_param, new_monomial)
-
- # match len(new_monomial):
- # case 0:
- # A[row, col] = value
- # case 1:
- # B[row, col] = value
- # case 2:
- # C[row, col] = value
- # case _:
- # raise Exception(f'illegal case {new_monomial=}')
-
- # for row in range(self.shape[0]):
- # try:
- # underlying_terms = self.get_poly(row, 0)
- # except KeyError:
- # continue
-
- # populate_matrices(
- # monomial_terms=underlying_terms,
- # row=row,
- # )
-
- # current_row = self.shape[0]
-
- # for key, monomial_terms in state.auxillary_equations.items():
- # if key in ordered_variable_index:
- # populate_matrices(
- # monomial_terms=monomial_terms,
- # row=current_row,
- # )
- # current_row += 1
-
- # # assert current_row == n_param, f'{current_row} is not {n_param}'
-
- # return (A, B, C), ordered_variable_index
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
index ce6b5c2..1fff79b 100644
--- a/polymatrix/expression/mixins/quadraticinexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticinexprmixin.py
@@ -5,6 +5,7 @@ from polymatrix.expression.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.expression.expressionstate import ExpressionState
+from polymatrix.expression.utils.getvariableindices import get_variable_indices
class QuadraticInExprMixin(ExpressionBaseMixin):
@@ -18,11 +19,6 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
def variables(self) -> tuple:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return 2*(len(self.variables),)
-
# overwrites abstract method of `ExpressionBaseMixin`
def _apply(
self,
@@ -32,7 +28,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
assert underlying.shape == (1, 1)
- variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
+ state, variable_indices = get_variable_indices(state, self.variables)
terms = {}
@@ -46,7 +42,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
for monomial, value in underlying_terms.items():
- x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+ x_monomial = tuple(variable_indices.index(var_idx) for var_idx in monomial if var_idx in variable_indices)
p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
assert len(x_monomial) == 2, f'{x_monomial} should be of length 2'
@@ -63,7 +59,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
monomial_terms[p_monomial] = 0
monomial_terms[p_monomial] += value
-
+
poly_matrix = init_poly_matrix(
terms=terms,
shape=2*(len(self.variables),),
diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py
new file mode 100644
index 0000000..886f074
--- /dev/null
+++ b/polymatrix/expression/mixins/reshapeexprmixin.py
@@ -0,0 +1,75 @@
+import abc
+import functools
+import operator
+import dataclass_abc
+import numpy as np
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+class ReshapeExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractclassmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractclassmethod
+ def new_shape(self) -> tuple[int, int]:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def _apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ state, underlying = self.underlying.apply(state)
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class ReshapePolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ shape: tuple[int, int]
+ underlying_shape: tuple[int, int]
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ index = row + self.shape[0] * col
+
+ underlying_col = int(index / self.underlying_shape[0])
+ underlying_row = index - underlying_col * self.underlying_shape[0]
+
+ # print(f'{row=}, {col=}')
+ # print(f'{underlying_row=}, {underlying_col=}')
+
+ return self.underlying.get_poly(underlying_row, underlying_col)
+
+ # replace '-1' by the remaining number of elements
+ if -1 in self.new_shape:
+ n_total = underlying.shape[0] * underlying.shape[1]
+
+ remaining_shape = tuple(e for e in self.new_shape if e != -1)
+
+ assert len(remaining_shape) + 1 == len(self.new_shape)
+
+ n_used = functools.reduce(operator.mul, remaining_shape)
+
+ n_remaining = int(n_total / n_used)
+
+ def gen_shape():
+ for e in self.new_shape:
+ if e == -1:
+ yield n_remaining
+ else:
+ yield e
+
+ new_shape = tuple(gen_shape())
+
+ else:
+ new_shape = self.new_shape
+
+ return state, ReshapePolyMatrix(
+ underlying=underlying,
+ shape=new_shape,
+ underlying_shape=underlying.shape,
+ ) \ No newline at end of file
diff --git a/polymatrix/expression/reshapeexpr.py b/polymatrix/expression/reshapeexpr.py
new file mode 100644
index 0000000..01ea7dd
--- /dev/null
+++ b/polymatrix/expression/reshapeexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.reshapeexprmixin import ReshapeExprMixin
+
+class ReshapeExpr(ReshapeExprMixin):
+ pass
diff --git a/polymatrix/statemonad/__init__.py b/polymatrix/statemonad/__init__.py
index 33aa86a..bb280e0 100644
--- a/polymatrix/statemonad/__init__.py
+++ b/polymatrix/statemonad/__init__.py
@@ -2,6 +2,13 @@ from polymatrix.statemonad.init.initstatemonad import init_state_monad
from polymatrix.statemonad.statemonad import StateMonad
+def from_(val):
+ def func(state):
+ return state, val
+
+ return init_state_monad(func)
+
+
def zip(monads: tuple[StateMonad]):
def zip_func(state):
@@ -14,3 +21,5 @@ def zip(monads: tuple[StateMonad]):
return state, values
return init_state_monad(zip_func)
+
+
diff --git a/polymatrix/statemonad/mixins/statemixin.py b/polymatrix/statemonad/mixins/statemixin.py
new file mode 100644
index 0000000..6e2855b
--- /dev/null
+++ b/polymatrix/statemonad/mixins/statemixin.py
@@ -0,0 +1,8 @@
+import abc
+
+
+class StateCacheMixin(abc.ABC):
+ @property
+ @abc.abstractmethod
+ def cache(self) -> dict:
+ ...
diff --git a/polymatrix/statemonad/mixins/statemonadmixin.py b/polymatrix/statemonad/mixins/statemonadmixin.py
index 39b6576..c367708 100644
--- a/polymatrix/statemonad/mixins/statemonadmixin.py
+++ b/polymatrix/statemonad/mixins/statemonadmixin.py
@@ -3,7 +3,9 @@ import dataclasses
from typing import Callable, Tuple, TypeVar, Generic
import typing
-State = TypeVar('State')
+from polymatrix.statemonad.mixins.statemixin import StateCacheMixin
+
+State = TypeVar('State', bound=StateCacheMixin)
U = TypeVar('U')
V = TypeVar('V')
@@ -17,13 +19,6 @@ class StateMonadMixin(
def apply_func(self) -> typing.Callable[[State], tuple[State, U]]:
...
- # def init(func: Callable[[State], Tuple[State, U]]) -> 'StateMonadMixin[U, State]':
- # class StateMonadImpl(StateMonadMixin):
- # def apply(self, state: State) -> Tuple[State, U]:
- # return func(state)
-
- # return StateMonadImpl()
-
def map(self, fn: Callable[[U], V]) -> 'StateMonadMixin[State, V]':
def internal_map(state: State) -> Tuple[State, U]:
@@ -48,6 +43,21 @@ class StateMonadMixin(
return dataclasses.replace(self, apply_func=internal_map)
- # @abc.abstractmethod
+ def cache(self) -> 'StateMonadMixin':
+ def internal_map(state: State) -> Tuple[State, V]:
+ if self in state.cache:
+ return state, state.cache[self]
+
+ state, val = self.apply(state)
+
+ state = dataclasses.replace(
+ state,
+ cache=state.cache | {self: val},
+ )
+
+ return state, val
+
+ return dataclasses.replace(self, apply_func=internal_map)
+
def apply(self, state: State) -> Tuple[State, U]:
return self.apply_func(state)