diff options
30 files changed, 788 insertions, 301 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 139e391..20b05a9 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -1,4 +1,7 @@ +import collections +import dataclasses import itertools +import typing import numpy as np import scipy.sparse # import polymatrix.statemonad @@ -6,6 +9,7 @@ import scipy.sparse from polymatrix.expression.expression import Expression from polymatrix.expression.expressionstate import ExpressionState from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr +from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr from polymatrix.expression.init.initexpression import init_expression from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr @@ -27,6 +31,14 @@ def from_( ) +def from_polymatrix( + polymatrix: PolyMatrix, +): + return init_expression( + init_from_terms_expr(polymatrix) + ) + + def accumulate( expr, func, @@ -34,7 +46,6 @@ def accumulate( ): def lifted_func(acc, polymat: PolyMatrix): - # # print(f'{terms=}') # print(f'{terms=}') @@ -62,103 +73,228 @@ def v_stack( ) -def kkt( - cost: Expression, +def h_stack( + expressions: tuple[Expression], +): + return init_expression( + init_v_stack_expr(tuple(expr.T for expr in expressions)) + ).T + + +def block_diag( + expressions: tuple[Expression], +): + return init_expression( + init_block_diag_expr(expressions) + ) + + +# def kkt( +# cost: Expression, +# variables: Expression, +# equality: Expression = None, +# inequality: Expression = None, +# ): +# return init_expression( +# init_kkt_expr( +# cost=cost, +# equality=equality, +# variables=variables, +# inequality=inequality, +# ) +# ) + +def kkt_equality( variables: Expression, equality: Expression = None, +): + + self_variables = variables + self_equality = equality + + def func(state: ExpressionState): + + state, equality = self_equality.apply(state=state) + + state, equality_der = self_equality.diff( + self_variables, + introduce_derivatives=True, + ).apply(state) + + def acc_nu_variables(acc, v): + state, nu_variables = acc + + nu_variable = state.n_param + state = state.register(n_param=1) + + return state, nu_variables + [nu_variable] + + *_, (state, nu_variables) = tuple(itertools.accumulate( + range(equality.shape[0]), + acc_nu_variables, + initial=(state, []), + )) + + terms = {} + + for row in range(equality_der.shape[1]): + + monomial_terms = collections.defaultdict(float) + + for eq_idx, nu_variable in enumerate(nu_variables): + + try: + underlying_terms = equality_der.get_poly(eq_idx, row) + except KeyError: + continue + + for monomial, value in underlying_terms.items(): + new_monomial = monomial + (nu_variable,) + + monomial_terms[new_monomial] += value + + terms[row, 0] = dict(monomial_terms) + + cost_expr = init_expression(init_from_terms_expr( + terms=terms, + shape=(equality_der.shape[1], 1), + )) + + nu_terms = {} + + for eq_idx, nu_variable in enumerate(nu_variables): + nu_terms[eq_idx, 0] = {(nu_variable,): 1} + + nu_expr = init_expression(init_from_terms_expr( + terms=nu_terms, + shape=(len(nu_variables), 1), + )) + + return state, (nu_expr, cost_expr) + + return init_state_monad(func) + +def kkt_inequality( + variables: Expression, inequality: Expression = None, ): - return init_expression( - init_kkt_expr( - cost=cost, - equality=equality, - variables=variables, - inequality=inequality, - ) - ) - # self_cost = cost - # self_variables = variables - # self_equality = equality + self_variables = variables + self_inequality = inequality - # def func(state: ExpressionState): - # state, cost = self_cost.apply(state=state) + def func(state: ExpressionState): - # assert cost.shape[1] == 1 + state, inequality = self_inequality.apply(state=state) - # if self_equality is not None: + state, inequality_der = self_inequality.diff( + self_variables, + introduce_derivatives=True, + ).apply(state) - # state, equality = self_equality.apply(state=state) + def acc_lambda_variables(acc, v): + state, lambda_variables = acc - # state, equality_der = self_equality.diff( - # self_variables, - # introduce_derivatives=True, - # ).apply(state) + lambda_variable = state.n_param + state = state.register(n_param=3) + + return state, lambda_variables + [lambda_variable] - # assert cost.shape[0] == equality_der.shape[1] + *_, (state, lambda_variables) = tuple(itertools.accumulate( + range(inequality.shape[0]), + acc_lambda_variables, + initial=(state, []), + )) - # def acc_nu_variables(acc, v): - # state, nu_variables = acc + terms = {} - # nu_variable = state.n_param - # state = state.register(n_param=1) - - # return state, nu_variables + [nu_variable] + for row in range(inequality_der.shape[1]): - # *_, (state, nu_variables) = tuple(itertools.accumulate( - # range(equality.shape[0]), - # acc_nu_variables, - # initial=(state, []), - # )) + monomial_terms = collections.defaultdict(float) - # else: - # nu_variables = tuple() + for inequality_idx, lambda_variable in enumerate(lambda_variables): - # idx_start = 0 + try: + underlying_terms = inequality_der.get_poly(inequality_idx, row) + except KeyError: + continue - # terms = {} + for monomial, value in underlying_terms.items(): + new_monomial = monomial + (lambda_variable,) + + monomial_terms[new_monomial] += value + + terms[row, 0] = dict(monomial_terms) + + cost_expr = init_expression(init_from_terms_expr( + terms=terms, + shape=(inequality_der.shape[1], 1), + )) + + # inequality, dual feasibility, complementary slackness + # ----------------------------------------------------- + + inequality_terms = {} + feasibility_terms = {} + complementary_terms = {} + + for inequality_idx, lambda_variable in enumerate(lambda_variables): + r_lambda = lambda_variable + 1 + r_inequality = lambda_variable + 2 - # for row in range(cost.shape[0]): - # try: - # monomial_terms = cost.get_poly(row, 0) - # except KeyError: - # monomial_terms = {} + try: + underlying_terms = inequality.get_poly(inequality_idx, 0) + except KeyError: + continue + + # f(x) <= -0.01 + inequality_terms[inequality_idx, 0] = underlying_terms | {(r_inequality, r_inequality): 1} - # for eq_idx, nu_variable in enumerate(nu_variables): + # dual feasibility, lambda >= 0 + feasibility_terms[inequality_idx, 0] = {(lambda_variable,): 1, (r_lambda, r_lambda): -1} - # try: - # underlying_terms = equality_der.get_poly(eq_idx, row) - # except KeyError: - # continue + # complementary slackness + complementary_terms[inequality_idx, 0] = {(r_lambda, r_inequality): 1} - # for monomial, value in underlying_terms.items(): - # new_monomial = monomial + (nu_variable,) + inequality_expr = init_expression(init_from_terms_expr( + terms=inequality_terms, + shape=(len(lambda_variables), 1), + )) - # if new_monomial not in monomial_terms: - # monomial_terms[new_monomial] = 0 + feasibility_expr = init_expression(init_from_terms_expr( + terms=feasibility_terms, + shape=(len(lambda_variables), 1), + )) - # monomial_terms[new_monomial] += value + complementary_expr = init_expression(init_from_terms_expr( + terms=complementary_terms, + shape=(len(lambda_variables), 1), + )) - # terms[idx_start, 0] = monomial_terms - # idx_start += 1 + # lambda expression + # ----------------- - # cost_expr = init_expression(init_from_terms_expr( - # terms=terms, - # shape=(idx_start, 1), - # )) + terms = {} + for inequality_idx, lambda_variable in enumerate(lambda_variables): + terms[inequality_idx, 0] = {(lambda_variable,): 1} - # terms = {} - # for eq_idx, nu_variable in enumerate(nu_variables): - # terms[eq_idx, 0] = {(nu_variable,): 1} + lambda_expr = init_expression(init_from_terms_expr( + terms=terms, + shape=(len(lambda_variables), 1), + )) + + return state, (lambda_expr, cost_expr, inequality_expr, feasibility_expr, complementary_expr) + + return init_state_monad(func) - # nu_expr = init_expression(init_from_terms_expr( - # terms=terms, - # shape=(len(nu_variables), 1), - # )) +# def to_polymatrix( +# expr: Expression, +# ): +# def func(state: ExpressionState): +# state, polymatrix = expr.apply(state) - # return state, (cost_expr, nu_expr) +# return state, polymatrix - # return StateMonadMixin.init(func) +# return init_state_monad(func) # def to_linear_matrix( # expr: Expression, @@ -181,13 +317,62 @@ def kkt( # return StateMonad.init(func) +@dataclasses.dataclass +class MatrixEquations: + matrix_equations: tuple[tuple[np.ndarray, ...], ...] + auxillary_matrix_equations: typing.Optional[tuple[np.ndarray, ...]] + variable_index: tuple[int, ...] + state: ExpressionState + + def merge_matrix_equations(self): + def gen_matrices(index: int): + for equations in self.matrix_equations: + if index < len(equations): + yield equations[index] + + if index < len(self.auxillary_matrix_equations): + yield self.auxillary_matrix_equations[index] + + matrix_1 = np.vstack(tuple(gen_matrices(0))) + matrix_2 = np.vstack(tuple(gen_matrices(1))) + matrix_3 = scipy.sparse.vstack(tuple(gen_matrices(2))) + + return (matrix_1, matrix_2, matrix_3) + + def get_value(self, variable, value): + if isinstance(variable, Expression): + variable = variable.underlying + + offset = self.state.offset_dict[variable] + offset_idx = list(self.variable_index.index(idx) for idx in range(*offset)) + return value[offset_idx] + + def to_matrix_equations( - expr: Expression, -) -> StateMonadMixin[ExpressionState, tuple[tuple[np.ndarray, ...], tuple[int, ...]]]: + expr: tuple[Expression], +) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]: + + if isinstance(expr, Expression): + expr = (expr,) + def func(state: ExpressionState): - state, underlying = expr.apply(state) - assert underlying.shape[1] == 1 + def acc_underlying(acc, v): + state, underlying_list = acc + + state, underlying = v.apply(state) + + assert underlying.shape[1] == 1, f'{underlying.shape[1]} is not 1' + + return state, underlying_list + (underlying,) + + *_, (state, underlying_list) = tuple(itertools.accumulate( + expr, + acc_underlying, + initial=(state, tuple()), + )) + + # state, underlying = expr.apply(state) def gen_used_variables(): def gen_used_auxillary_variables(considered): @@ -199,73 +384,126 @@ def to_matrix_equations( if variable not in considered and variable in state.auxillary_equations: yield from gen_used_auxillary_variables(considered + (variable,)) - for row in range(underlying.shape[0]): - for col in range(underlying.shape[1]): + for underlying in underlying_list: + for row in range(underlying.shape[0]): + for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: - continue + try: + underlying_terms = underlying.get_poly(row, col) + except KeyError: + continue - for monomial in underlying_terms.keys(): - for variable in monomial: - yield variable + for monomial in underlying_terms.keys(): + for variable in monomial: + yield variable - if variable in state.auxillary_equations: - yield from gen_used_auxillary_variables((variable,)) + if variable in state.auxillary_equations: + yield from gen_used_auxillary_variables((variable,)) used_variables = set(gen_used_variables()) - ordered_variable_index = tuple(sorted(used_variables)) variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)} n_param = len(ordered_variable_index) - A = np.zeros((n_param, 1), dtype=np.float32) - B = np.zeros((n_param, n_param), dtype=np.float32) - C = scipy.sparse.dok_array((n_param, n_param**2), dtype=np.float32) + def gen_underlying_matrices(): + for underlying in underlying_list: + n_row = underlying.shape[0] + + A = np.zeros((n_row, 1), dtype=np.double) + B = np.zeros((n_row, n_param), dtype=np.double) + C = scipy.sparse.dok_array((n_row, n_param**2), dtype=np.double) + + # def populate_matrices(monomial_terms, row): + # for monomial, value in monomial_terms.items(): + # new_monomial = tuple(variable_index_map[var] for var in monomial) + # col = monomial_to_index(n_param, new_monomial) + + # match len(new_monomial): + # case 0: + # A[row, col] = value + # case 1: + # B[row, col] = value + # case 2: + # C[row, col] = value + # case _: + # raise Exception(f'illegal case {new_monomial=}') + + for row in range(underlying.shape[0]): + try: + underlying_terms = underlying.get_poly(row, 0) + except KeyError: + continue - def populate_matrices(monomial_terms, row): - for monomial, value in monomial_terms.items(): - new_monomial = tuple(variable_index_map[var] for var in monomial) - col = monomial_to_index(n_param, new_monomial) + # populate_matrices( + # monomial_terms=underlying_terms, + # row=row, + # ) + + for monomial, value in underlying_terms.items(): + new_monomial = tuple(variable_index_map[var] for var in monomial) + col = monomial_to_index(n_param, new_monomial) + + match len(new_monomial): + case 0: + A[row, col] = value + case 1: + B[row, col] = value + case 2: + C[row, col] = value + case _: + raise Exception(f'illegal case {new_monomial=}') + + yield A, B, C - match len(new_monomial): - case 0: - A[row, col] = value - case 1: - B[row, col] = value - case 2: - C[row, col] = value - case _: - raise Exception(f'illegal case {new_monomial=}') + underlying_matrices = tuple(gen_underlying_matrices()) - for row in range(underlying.shape[0]): - try: - underlying_terms = underlying.get_poly(row, 0) - except KeyError: - continue + # current_row = underlying.shape[0] + + def gen_auxillary_equations(): + for key, monomial_terms in state.auxillary_equations.items(): + if key in ordered_variable_index: + yield key, monomial_terms - populate_matrices( - monomial_terms=underlying_terms, - row=row, - ) + auxillary_equations = tuple(gen_auxillary_equations()) - current_row = underlying.shape[0] + n_row = len(auxillary_equations) - for key, monomial_terms in state.auxillary_equations.items(): - if key in ordered_variable_index: - populate_matrices( - monomial_terms=monomial_terms, - row=current_row, - ) - current_row += 1 + if n_row == 0: + auxillary_matrix_equations = None - # assert current_row == n_param, f'{current_row} is not {n_param}' + else: + A = np.zeros((n_row, 1), dtype=np.double) + B = np.zeros((n_row, n_param), dtype=np.double) + C = scipy.sparse.dok_array((n_row, n_param**2), dtype=np.double) - return state, ((A, B, C), ordered_variable_index) + for row, (key, monomial_terms) in enumerate(auxillary_equations): + for monomial, value in monomial_terms.items(): + new_monomial = tuple(variable_index_map[var] for var in monomial) + col = monomial_to_index(n_param, new_monomial) - return StateMonadMixin.init(func) + match len(new_monomial): + case 0: + A[row, col] = value + case 1: + B[row, col] = value + case 2: + C[row, col] = value + case _: + raise Exception(f'illegal case {new_monomial=}') + + auxillary_matrix_equations = (A, B, C) + + result = MatrixEquations( + matrix_equations=underlying_matrices, + auxillary_matrix_equations=auxillary_matrix_equations, + variable_index=ordered_variable_index, + state=state, + ) + + return state, result + + return init_state_monad(func) def to_constant_matrix( expr: Expression, diff --git a/polymatrix/expression/blockdiagexpr.py b/polymatrix/expression/blockdiagexpr.py new file mode 100644 index 0000000..e3acee5 --- /dev/null +++ b/polymatrix/expression/blockdiagexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin + +class BlockDiagExpr(BlockDiagExprMixin): + pass diff --git a/polymatrix/expression/cacheexpr.py b/polymatrix/expression/cacheexpr.py new file mode 100644 index 0000000..5ae4052 --- /dev/null +++ b/polymatrix/expression/cacheexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.cacheexprmixin import CacheExprMixin + +class CacheExpr(CacheExprMixin): + pass diff --git a/polymatrix/expression/impl/blockdiagexprimpl.py b/polymatrix/expression/impl/blockdiagexprimpl.py new file mode 100644 index 0000000..a2707d8 --- /dev/null +++ b/polymatrix/expression/impl/blockdiagexprimpl.py @@ -0,0 +1,7 @@ +import dataclass_abc +from polymatrix.expression.blockdiagexpr import BlockDiagExpr + + +@dataclass_abc.dataclass_abc(frozen=True) +class BlockDiagExprImpl(BlockDiagExpr): + underlying: tuple diff --git a/polymatrix/expression/impl/cacheexprimpl.py b/polymatrix/expression/impl/cacheexprimpl.py new file mode 100644 index 0000000..a6a74a1 --- /dev/null +++ b/polymatrix/expression/impl/cacheexprimpl.py @@ -0,0 +1,8 @@ +import dataclass_abc +from polymatrix.expression.cacheexpr import CacheExpr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class CacheExprImpl(CacheExpr): + underlying: ExpressionBaseMixin diff --git a/polymatrix/expression/impl/expressionstateimpl.py b/polymatrix/expression/impl/expressionstateimpl.py index 7513970..5459eb7 100644 --- a/polymatrix/expression/impl/expressionstateimpl.py +++ b/polymatrix/expression/impl/expressionstateimpl.py @@ -1,3 +1,4 @@ +from functools import cached_property import dataclass_abc from polymatrix.expression.expressionstate import ExpressionState @@ -8,4 +9,4 @@ class ExpressionStateImpl(ExpressionState): n_param: int offset_dict: dict auxillary_equations: dict[int, dict[tuple[int], float]] - cached_polymatrix: dict + cache: dict diff --git a/polymatrix/expression/impl/reshapeexprimpl.py b/polymatrix/expression/impl/reshapeexprimpl.py new file mode 100644 index 0000000..7c16f16 --- /dev/null +++ b/polymatrix/expression/impl/reshapeexprimpl.py @@ -0,0 +1,9 @@ +import dataclass_abc +from polymatrix.expression.reshapeexpr import ReshapeExpr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class ReshapeExprImpl(ReshapeExpr): + underlying: ExpressionBaseMixin + new_shape: tuple diff --git a/polymatrix/expression/init/initblockdiagexpr.py b/polymatrix/expression/init/initblockdiagexpr.py new file mode 100644 index 0000000..930385f --- /dev/null +++ b/polymatrix/expression/init/initblockdiagexpr.py @@ -0,0 +1,9 @@ +from polymatrix.expression.impl.blockdiagexprimpl import BlockDiagExprImpl + + +def init_block_diag_expr( + underlying: tuple, +): + return BlockDiagExprImpl( + underlying=underlying, +) diff --git a/polymatrix/expression/init/initcacheexpr.py b/polymatrix/expression/init/initcacheexpr.py new file mode 100644 index 0000000..977e033 --- /dev/null +++ b/polymatrix/expression/init/initcacheexpr.py @@ -0,0 +1,10 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.cacheexprimpl import CacheExprImpl + + +def init_cache_expr( + underlying: ExpressionBaseMixin, +): + return CacheExprImpl( + underlying=underlying, +) diff --git a/polymatrix/expression/init/initexpressionstate.py b/polymatrix/expression/init/initexpressionstate.py index 7e8a6fe..be2e4f1 100644 --- a/polymatrix/expression/init/initexpressionstate.py +++ b/polymatrix/expression/init/initexpressionstate.py @@ -15,5 +15,5 @@ def init_expression_state( n_param=n_param, offset_dict=offset_dict, auxillary_equations={}, - cached_polymatrix={}, + cache={}, ) diff --git a/polymatrix/expression/init/initfromtermsexpr.py b/polymatrix/expression/init/initfromtermsexpr.py index 80d5198..1abac12 100644 --- a/polymatrix/expression/init/initfromtermsexpr.py +++ b/polymatrix/expression/init/initfromtermsexpr.py @@ -1,10 +1,19 @@ +import typing from polymatrix.expression.impl.fromtermsexprimpl import FromTermsExprImpl +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin def init_from_terms_expr( - terms: tuple, - shape: tuple[int, int] + terms: typing.Union[tuple, PolyMatrixMixin], + shape: tuple[int, int] = None, ): + if isinstance(terms, PolyMatrixMixin): + shape = terms.shape + terms = terms.get_terms() + + else: + assert shape is not None + if isinstance(terms, dict): terms = tuple((key, tuple(value.items())) for key, value in terms.items()) diff --git a/polymatrix/expression/init/initreshapeexpr.py b/polymatrix/expression/init/initreshapeexpr.py new file mode 100644 index 0000000..f95fb00 --- /dev/null +++ b/polymatrix/expression/init/initreshapeexpr.py @@ -0,0 +1,12 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.reshapeexprimpl import ReshapeExprImpl + + +def init_reshape_expr( + underlying: ExpressionBaseMixin, + new_shape: tuple, +): + return ReshapeExprImpl( + underlying=underlying, + new_shape=new_shape, +) diff --git a/polymatrix/expression/mixins/addauxequationsexprmixin.py b/polymatrix/expression/mixins/addauxequationsexprmixin.py new file mode 100644 index 0000000..1dae4f7 --- /dev/null +++ b/polymatrix/expression/mixins/addauxequationsexprmixin.py @@ -0,0 +1,58 @@ + +import abc +import dataclasses +import itertools +import dataclass_abc +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.expressionstate import ExpressionState + + +# is this really needed? +class AddAuxEquationsExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def _apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + + state, underlying = self.underlying.apply(state=state) + + assert underlying.shape[1] == 1 + + @dataclass_abc.dataclass_abc(frozen=True) + class AddAuxEquationsPolyMatrix(PolyMatrixMixin): + underlying: tuple[PolyMatrixMixin] + shape: tuple[int, int] + n_row: int + auxillary_equations: tuple[dict[tuple[int], float]] + + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: + if row < self.n_row: + return self.underlying.get_poly(row, col) + + elif row < self.shape[0]: + return self.auxillary_equations[row - self.n_row] + + else: + raise Exception(f'row {row} is out of bounds') + + auxillary_equations = tuple(state.auxillary_equations.values()) + + polymat = AddAuxEquationsPolyMatrix( + underlying=underlying, + shape=(underlying.shape[0] + len(auxillary_equations), 1), + n_row=underlying.shape[0], + auxillary_equations=auxillary_equations, + ) + + state = dataclasses.replace(state, auxillary_equations={}) + + return state, polymat diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py new file mode 100644 index 0000000..6bdbce1 --- /dev/null +++ b/polymatrix/expression/mixins/blockdiagexprmixin.py @@ -0,0 +1,67 @@ + +import abc +import itertools +import dataclass_abc +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.expressionstate import ExpressionState + + +class BlockDiagExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> tuple[ExpressionBaseMixin, ...]: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def _apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + + all_underlying = [] + for expr in self.underlying: + state, polymat = expr.apply(state=state) + all_underlying.append(polymat) + + # assert all(underlying.shape[0] == underlying.shape[1] for underlying in all_underlying) + + @dataclass_abc.dataclass_abc(frozen=True) + class BlockDiagPolyMatrix(PolyMatrixMixin): + all_underlying: tuple[PolyMatrixMixin] + underlying_row_col_range: tuple[tuple[int, int], ...] + shape: tuple[int, int] + + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: + for polymat, ((row_start, col_start), (row_end, col_end)) in zip(self.all_underlying, self.underlying_row_col_range): + + if row_start <= row < row_end: + if col_start <= col < col_end: + return polymat.get_poly( + row=row-row_start, + col=col-col_start, + ) + + else: + raise KeyError() + + raise Exception(f'row {row} is out of bounds') + + underlying_row_col_range = tuple(itertools.pairwise( + itertools.accumulate( + (expr.shape for expr in all_underlying), + lambda acc, v: tuple(v1+v2 for v1, v2 in zip(acc, v)), + initial=(0, 0)) + )) + + shape = underlying_row_col_range[-1][1] + + polymat = BlockDiagPolyMatrix( + all_underlying=all_underlying, + shape=shape, + underlying_row_col_range=underlying_row_col_range, + ) + + return state, polymat diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py new file mode 100644 index 0000000..8779ba1 --- /dev/null +++ b/polymatrix/expression/mixins/cacheexprmixin.py @@ -0,0 +1,40 @@ + +import abc +import dataclasses + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin + + +class CacheExprMixin(ExpressionBaseMixin): + @property + @abc.abstractclassmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def _apply( + self, + state: ExpressionStateMixin, + ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + + if self in state.cache: + return state, state.cache[self] + + state, underlying = self.underlying.apply(state) + + cached_terms = dict(underlying.get_terms()) + + poly_matrix = init_poly_matrix( + terms=cached_terms, + shape=underlying.shape, + ) + + state = dataclasses.replace( + state, + cache=state.cache | {self: poly_matrix}, + ) + + return state, poly_matrix diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py index 5b69d3f..5436ed6 100644 --- a/polymatrix/expression/mixins/determinantexprmixin.py +++ b/polymatrix/expression/mixins/determinantexprmixin.py @@ -26,8 +26,8 @@ class DeterminantExprMixin(ExpressionBaseMixin): self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: - if self in state.cached_polymatrix: - return state, state.cached_polymatrix[self] + if self in state.cache: + return state, state.cache[self] state, underlying = self.underlying.apply(state=state) @@ -110,7 +110,7 @@ class DeterminantExprMixin(ExpressionBaseMixin): state = dataclasses.replace( state, auxillary_equations=state.auxillary_equations | auxillary_equations, - cached_polymatrix=state.cached_polymatrix | {self: poly_matrix}, + cache=state.cache | {self: poly_matrix}, ) return state, poly_matrix diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py index 4b25ec6..cfb6bea 100644 --- a/polymatrix/expression/mixins/divisionexprmixin.py +++ b/polymatrix/expression/mixins/divisionexprmixin.py @@ -30,8 +30,8 @@ class DivisionExprMixin(ExpressionBaseMixin): state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: - if self in state.cached_polymatrix: - return state, state.cached_polymatrix[self] + if self in state.cache: + return state, state.cache[self] state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) @@ -76,7 +76,7 @@ class DivisionExprMixin(ExpressionBaseMixin): state = dataclasses.replace( state, auxillary_equations=state.auxillary_equations | {division_variable: auxillary_terms}, - cached_polymatrix=state.cached_polymatrix | {self: poly_matrix}, + cached_polymatrix=state.cache | {self: poly_matrix}, ) return state, poly_matrix diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py index 4d78b26..d4fdbea 100644 --- a/polymatrix/expression/mixins/evalexprmixin.py +++ b/polymatrix/expression/mixins/evalexprmixin.py @@ -73,13 +73,15 @@ class EvalExprMixin(ExpressionBaseMixin): initial=(tuple(), value), )) - # print(f'{new_monomial=}') - if new_monomial not in terms_row_col: terms_row_col[new_monomial] = 0 terms_row_col[new_monomial] += new_value + # delete zero entries + if terms_row_col[new_monomial] == 0: + del terms_row_col[new_monomial] + terms[row, col] = terms_row_col poly_matrix = init_poly_matrix( diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py index baf605d..3825ed1 100644 --- a/polymatrix/expression/mixins/expressionbasemixin.py +++ b/polymatrix/expression/mixins/expressionbasemixin.py @@ -1,11 +1,9 @@ import abc from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin -from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin class ExpressionBaseMixin( - # StateMonad[ExpressionStateMixin, PolyMatrixMixin], abc.ABC ): @@ -13,7 +11,6 @@ class ExpressionBaseMixin( def _apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: ... - @abc.abstractmethod def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: assert isinstance(state, ExpressionStateMixin), f'{state} is not of type {ExpressionStateMixin.__name__}' diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py index aa9ca17..4cb4a50 100644 --- a/polymatrix/expression/mixins/expressionmixin.py +++ b/polymatrix/expression/mixins/expressionmixin.py @@ -5,6 +5,7 @@ import numpy as np from sympy import re from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr from polymatrix.expression.init.initadditionexpr import init_addition_expr +from polymatrix.expression.init.initcacheexpr import init_cache_expr from polymatrix.expression.init.initderivativeexpr import init_derivative_expr from polymatrix.expression.init.initdeterminantexpr import init_determinant_expr from polymatrix.expression.init.initdivisionexpr import init_division_expr @@ -18,6 +19,7 @@ from polymatrix.expression.init.initmatrixmultexpr import init_matrix_mult_expr from polymatrix.expression.init.initparametrizetermsexpr import init_parametrize_terms_expr from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr from polymatrix.expression.init.initrepmatexpr import init_rep_mat_expr +from polymatrix.expression.init.initreshapeexpr import init_reshape_expr from polymatrix.expression.init.inittoquadraticexpr import init_to_quadratic_expr from polymatrix.expression.init.inittransposeexpr import init_transpose_expr @@ -37,7 +39,7 @@ class ExpressionMixin( ... # overwrites abstract method of `PolyMatrixExprBaseMixin` - def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: + def _apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: return self.underlying.apply(state) # # overwrites abstract method of `PolyMatrixExprBaseMixin` @@ -45,11 +47,6 @@ class ExpressionMixin( # def degrees(self) -> set[int]: # return self.underlying.degrees - # # overwrites abstract method of `PolyMatrixExprBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return self.underlying.shape - # def __iter__(self): # for row in range(self.shape[0]): # yield self[row, 0] @@ -93,7 +90,11 @@ class ExpressionMixin( def __mul__(self, other) -> 'ExpressionMixin': # assert isinstance(other, float) - right = init_from_array_expr(other) + match other: + case ExpressionBaseMixin(): + right = other.underlying + case _: + right = init_from_array_expr(other) return dataclasses.replace( self, @@ -106,6 +107,9 @@ class ExpressionMixin( def __rmul__(self, other): return self * other + def __neg__(self): + return self * (-1) + def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin': match other: case ExpressionBaseMixin(): @@ -153,6 +157,14 @@ class ExpressionMixin( ), ) + def cache(self) -> 'ExpressionMixin': + return dataclasses.replace( + self, + underlying=init_cache_expr( + underlying=self.underlying, + ), + ) + def parametrize(self, name: str, variables: tuple) -> 'ExpressionMixin': return dataclasses.replace( self, @@ -163,6 +175,15 @@ class ExpressionMixin( ), ) + def reshape(self, n: int, m: int) -> 'ExpressionMixin': + return dataclasses.replace( + self, + underlying=init_reshape_expr( + underlying=self.underlying, + new_shape=(n, m), + ), + ) + def rep_mat(self, n: int, m: int) -> 'ExpressionMixin': return dataclasses.replace( self, diff --git a/polymatrix/expression/mixins/expressionstatemixin.py b/polymatrix/expression/mixins/expressionstatemixin.py index 526b422..523d2b7 100644 --- a/polymatrix/expression/mixins/expressionstatemixin.py +++ b/polymatrix/expression/mixins/expressionstatemixin.py @@ -5,8 +5,12 @@ import typing import sympy +from polymatrix.statemonad.mixins.statemixin import StateCacheMixin -class ExpressionStateMixin(abc.ABC): + +class ExpressionStateMixin( + StateCacheMixin, +): @property @abc.abstractmethod @@ -27,10 +31,10 @@ class ExpressionStateMixin(abc.ABC): def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]: ... - @property - @abc.abstractmethod - def cached_polymatrix(self) -> dict: - ... + # @property + # @abc.abstractmethod + # def cache(self) -> dict: + # ... def register( self, diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index 557bed0..9ea4e92 100644 --- a/polymatrix/expression/mixins/linearinexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -7,6 +7,7 @@ from polymatrix.expression.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.polymatrix import PolyMatrix from polymatrix.expression.expressionstate import ExpressionState +from polymatrix.expression.utils.getvariableindices import get_variable_indices class LinearInExprMixin(ExpressionBaseMixin): @@ -28,8 +29,8 @@ class LinearInExprMixin(ExpressionBaseMixin): state, underlying = self.underlying.apply(state=state) # todo: uncomment this - # state, variable_indices = get_variable_indices(state, self.variables) - variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict) + state, variable_indices = get_variable_indices(state, self.variables) + # variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict) terms = {} idx_row = 0 @@ -49,6 +50,8 @@ class LinearInExprMixin(ExpressionBaseMixin): x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices) + assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted' + x_monomial_terms[x_monomial][p_monomial] += value for data in x_monomial_terms.values(): diff --git a/polymatrix/expression/mixins/parametrizetermsexprmixin.py b/polymatrix/expression/mixins/parametrizetermsexprmixin.py index f20306c..255a5da 100644 --- a/polymatrix/expression/mixins/parametrizetermsexprmixin.py +++ b/polymatrix/expression/mixins/parametrizetermsexprmixin.py @@ -28,56 +28,14 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin): def variables(self) -> tuple: ... - # @property - # def shape(self) -> tuple[int, int]: - # return self.underlying.shape - - # @dataclass_abc.dataclass_abc - # class ParametrizeTermsPolyMatrix(PolyMatrixAsDictMixin): - # shape: tuple[int, int] - # terms: dict - # start_index: int - # n_param: int - - # @property - # def param(self) -> tuple[int, int]: - # outer_self = self - - # @dataclass_abc.dataclass_abc(frozen=True) - # class ParameterExpr(ExpressionBaseMixin): - # def apply( - # self, - # state: ExpressionStateMixin, - # ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: - - # state, poly_matrix = outer_self.apply(state) - - # n_param = poly_matrix.n_param - # start_index = poly_matrix.start_index - - # def gen_monomials(): - # for rel_index in range(n_param): - # yield {(start_index + rel_index,): 1} - - # terms = {(row, 0): monomial_terms for row, monomial_terms in enumerate(gen_monomials())} - - # poly_matrix = init_poly_matrix( - # terms=terms, - # shape=(n_param, 1), - # ) - - # return state, poly_matrix - - # return ParameterExpr() - # overwrites abstract method of `ExpressionBaseMixin` def _apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: - if self in state.cached_polymatrix: - return state, state.cached_polymatrix[self] + if self in state.cache: + return state, state.cache[self] # if not hasattr(self, '_terms'): state, underlying = self.underlying.apply(state) @@ -124,13 +82,6 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin): n_param=start_index - state.n_param, ) - # poly_matrix = ParametrizeTermsExprMixin.ParametrizeTermsPolyMatrix( - # terms=terms, - # shape=underlying.shape, - # start_index=idx_start, - # n_param=n_param, - # ) - poly_matrix = init_poly_matrix( terms=terms, shape=underlying.shape, @@ -138,7 +89,7 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin): state = dataclasses.replace( state, - cached_polymatrix=state.cached_polymatrix | {self: poly_matrix}, + cache=state.cache | {self: poly_matrix}, ) return state, poly_matrix diff --git a/polymatrix/expression/mixins/polymatrixmixin.py b/polymatrix/expression/mixins/polymatrixmixin.py index f83a62c..89e05cb 100644 --- a/polymatrix/expression/mixins/polymatrixmixin.py +++ b/polymatrix/expression/mixins/polymatrixmixin.py @@ -14,88 +14,19 @@ class PolyMatrixMixin(abc.ABC): def shape(self) -> tuple[int, int]: ... + def get_terms(self) -> tuple[tuple[int, int], dict[tuple[int, ...], float]]: + def gen_terms(): + for row in range(self.shape[0]): + for col in range(self.shape[1]): + try: + monomial_terms = self.get_poly(row, col) + except KeyError: + continue + + yield (row, col), monomial_terms + + return tuple(gen_terms()) + @abc.abstractclassmethod def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: ... - - # def get_equations( - # self, - # state: ExpressionStateMixin, - # ): - # assert self.shape[1] == 1 - - # def gen_used_variables(): - # def gen_used_auxillary_variables(considered): - # monomial_terms = state.auxillary_equations[considered[-1]] - # for monomial in monomial_terms.keys(): - # for variable in monomial: - # yield variable - - # if variable not in considered and variable in state.auxillary_equations: - # yield from gen_used_auxillary_variables(considered + (variable,)) - - # for row in range(self.shape[0]): - # for col in range(self.shape[1]): - - # try: - # underlying_terms = self.get_poly(row, col) - # except KeyError: - # continue - - # for monomial in underlying_terms.keys(): - # for variable in monomial: - # yield variable - - # if variable in state.auxillary_equations: - # yield from gen_used_auxillary_variables((variable,)) - - # used_variables = set(gen_used_variables()) - - # ordered_variable_index = tuple(sorted(used_variables)) - # variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)} - - # n_param = len(ordered_variable_index) - - # A = np.zeros((n_param, 1), dtype=np.float32) - # B = np.zeros((n_param, n_param), dtype=np.float32) - # C = scipy.sparse.dok_array((n_param, n_param**2), dtype=np.float32) - - # def populate_matrices(monomial_terms, row): - # for monomial, value in monomial_terms.items(): - # new_monomial = tuple(variable_index_map[var] for var in monomial) - # col = monomial_to_index(n_param, new_monomial) - - # match len(new_monomial): - # case 0: - # A[row, col] = value - # case 1: - # B[row, col] = value - # case 2: - # C[row, col] = value - # case _: - # raise Exception(f'illegal case {new_monomial=}') - - # for row in range(self.shape[0]): - # try: - # underlying_terms = self.get_poly(row, 0) - # except KeyError: - # continue - - # populate_matrices( - # monomial_terms=underlying_terms, - # row=row, - # ) - - # current_row = self.shape[0] - - # for key, monomial_terms in state.auxillary_equations.items(): - # if key in ordered_variable_index: - # populate_matrices( - # monomial_terms=monomial_terms, - # row=current_row, - # ) - # current_row += 1 - - # # assert current_row == n_param, f'{current_row} is not {n_param}' - - # return (A, B, C), ordered_variable_index diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py index ce6b5c2..1fff79b 100644 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ b/polymatrix/expression/mixins/quadraticinexprmixin.py @@ -5,6 +5,7 @@ from polymatrix.expression.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.polymatrix import PolyMatrix from polymatrix.expression.expressionstate import ExpressionState +from polymatrix.expression.utils.getvariableindices import get_variable_indices class QuadraticInExprMixin(ExpressionBaseMixin): @@ -18,11 +19,6 @@ class QuadraticInExprMixin(ExpressionBaseMixin): def variables(self) -> tuple: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return 2*(len(self.variables),) - # overwrites abstract method of `ExpressionBaseMixin` def _apply( self, @@ -32,7 +28,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin): assert underlying.shape == (1, 1) - variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict) + state, variable_indices = get_variable_indices(state, self.variables) terms = {} @@ -46,7 +42,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin): for monomial, value in underlying_terms.items(): - x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) + x_monomial = tuple(variable_indices.index(var_idx) for var_idx in monomial if var_idx in variable_indices) p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices) assert len(x_monomial) == 2, f'{x_monomial} should be of length 2' @@ -63,7 +59,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin): monomial_terms[p_monomial] = 0 monomial_terms[p_monomial] += value - + poly_matrix = init_poly_matrix( terms=terms, shape=2*(len(self.variables),), diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py new file mode 100644 index 0000000..886f074 --- /dev/null +++ b/polymatrix/expression/mixins/reshapeexprmixin.py @@ -0,0 +1,75 @@ +import abc +import functools +import operator +import dataclass_abc +import numpy as np + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin + +class ReshapeExprMixin(ExpressionBaseMixin): + @property + @abc.abstractclassmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + @property + @abc.abstractclassmethod + def new_shape(self) -> tuple[int, int]: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def _apply( + self, + state: ExpressionStateMixin, + ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + + state, underlying = self.underlying.apply(state) + + @dataclass_abc.dataclass_abc(frozen=True) + class ReshapePolyMatrix(PolyMatrixMixin): + underlying: PolyMatrixMixin + shape: tuple[int, int] + underlying_shape: tuple[int, int] + + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: + index = row + self.shape[0] * col + + underlying_col = int(index / self.underlying_shape[0]) + underlying_row = index - underlying_col * self.underlying_shape[0] + + # print(f'{row=}, {col=}') + # print(f'{underlying_row=}, {underlying_col=}') + + return self.underlying.get_poly(underlying_row, underlying_col) + + # replace '-1' by the remaining number of elements + if -1 in self.new_shape: + n_total = underlying.shape[0] * underlying.shape[1] + + remaining_shape = tuple(e for e in self.new_shape if e != -1) + + assert len(remaining_shape) + 1 == len(self.new_shape) + + n_used = functools.reduce(operator.mul, remaining_shape) + + n_remaining = int(n_total / n_used) + + def gen_shape(): + for e in self.new_shape: + if e == -1: + yield n_remaining + else: + yield e + + new_shape = tuple(gen_shape()) + + else: + new_shape = self.new_shape + + return state, ReshapePolyMatrix( + underlying=underlying, + shape=new_shape, + underlying_shape=underlying.shape, + )
\ No newline at end of file diff --git a/polymatrix/expression/reshapeexpr.py b/polymatrix/expression/reshapeexpr.py new file mode 100644 index 0000000..01ea7dd --- /dev/null +++ b/polymatrix/expression/reshapeexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.reshapeexprmixin import ReshapeExprMixin + +class ReshapeExpr(ReshapeExprMixin): + pass diff --git a/polymatrix/statemonad/__init__.py b/polymatrix/statemonad/__init__.py index 33aa86a..bb280e0 100644 --- a/polymatrix/statemonad/__init__.py +++ b/polymatrix/statemonad/__init__.py @@ -2,6 +2,13 @@ from polymatrix.statemonad.init.initstatemonad import init_state_monad from polymatrix.statemonad.statemonad import StateMonad +def from_(val): + def func(state): + return state, val + + return init_state_monad(func) + + def zip(monads: tuple[StateMonad]): def zip_func(state): @@ -14,3 +21,5 @@ def zip(monads: tuple[StateMonad]): return state, values return init_state_monad(zip_func) + + diff --git a/polymatrix/statemonad/mixins/statemixin.py b/polymatrix/statemonad/mixins/statemixin.py new file mode 100644 index 0000000..6e2855b --- /dev/null +++ b/polymatrix/statemonad/mixins/statemixin.py @@ -0,0 +1,8 @@ +import abc + + +class StateCacheMixin(abc.ABC): + @property + @abc.abstractmethod + def cache(self) -> dict: + ... diff --git a/polymatrix/statemonad/mixins/statemonadmixin.py b/polymatrix/statemonad/mixins/statemonadmixin.py index 39b6576..c367708 100644 --- a/polymatrix/statemonad/mixins/statemonadmixin.py +++ b/polymatrix/statemonad/mixins/statemonadmixin.py @@ -3,7 +3,9 @@ import dataclasses from typing import Callable, Tuple, TypeVar, Generic import typing -State = TypeVar('State') +from polymatrix.statemonad.mixins.statemixin import StateCacheMixin + +State = TypeVar('State', bound=StateCacheMixin) U = TypeVar('U') V = TypeVar('V') @@ -17,13 +19,6 @@ class StateMonadMixin( def apply_func(self) -> typing.Callable[[State], tuple[State, U]]: ... - # def init(func: Callable[[State], Tuple[State, U]]) -> 'StateMonadMixin[U, State]': - # class StateMonadImpl(StateMonadMixin): - # def apply(self, state: State) -> Tuple[State, U]: - # return func(state) - - # return StateMonadImpl() - def map(self, fn: Callable[[U], V]) -> 'StateMonadMixin[State, V]': def internal_map(state: State) -> Tuple[State, U]: @@ -48,6 +43,21 @@ class StateMonadMixin( return dataclasses.replace(self, apply_func=internal_map) - # @abc.abstractmethod + def cache(self) -> 'StateMonadMixin': + def internal_map(state: State) -> Tuple[State, V]: + if self in state.cache: + return state, state.cache[self] + + state, val = self.apply(state) + + state = dataclasses.replace( + state, + cache=state.cache | {self: val}, + ) + + return state, val + + return dataclasses.replace(self, apply_func=internal_map) + def apply(self, state: State) -> Tuple[State, U]: return self.apply_func(state) |