diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-06-13 16:06:25 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-06-13 16:06:25 +0200 |
commit | 68eb5f9e15be57a32317810237061b16c96a3271 (patch) | |
tree | e02cd2160964ecc7f92ee2a35e8342bbe1c8de6f | |
parent | introduce state monad and functions to go along with it (diff) | |
download | polymatrix-68eb5f9e15be57a32317810237061b16c96a3271.tar.gz polymatrix-68eb5f9e15be57a32317810237061b16c96a3271.zip |
add eye, sum and symmetric operation
65 files changed, 1042 insertions, 206 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 20b05a9..85def18 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -11,10 +11,9 @@ from polymatrix.expression.expressionstate import ExpressionState from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr from polymatrix.expression.init.initexpression import init_expression +from polymatrix.expression.init.initeyeexpr import init_eye_expr from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr -from polymatrix.expression.init.initkktexpr import init_kkt_expr -from polymatrix.expression.init.initlinearmatrixinexpr import init_linear_matrix_in_expr from polymatrix.expression.init.initvstackexpr import init_v_stack_expr from polymatrix.expression.polymatrix import PolyMatrix from polymatrix.expression.utils.getvariableindices import get_variable_indices @@ -89,6 +88,14 @@ def block_diag( ) +def eye( + variable: tuple[Expression], +): + return init_expression( + init_eye_expr(variable=variable) + ) + + # def kkt( # cost: Expression, # variables: Expression, @@ -317,6 +324,17 @@ def kkt_inequality( # return StateMonad.init(func) +def shape( + expr: Expression, +) -> StateMonadMixin[ExpressionState, tuple[int, ...]]: + def func(state: ExpressionState): + state, polymatrix = expr.apply(state) + + return state, polymatrix.shape + + return init_state_monad(func) + + @dataclasses.dataclass class MatrixEquations: matrix_equations: tuple[tuple[np.ndarray, ...], ...] @@ -340,12 +358,16 @@ class MatrixEquations: return (matrix_1, matrix_2, matrix_3) def get_value(self, variable, value): - if isinstance(variable, Expression): - variable = variable.underlying - - offset = self.state.offset_dict[variable] - offset_idx = list(self.variable_index.index(idx) for idx in range(*offset)) - return value[offset_idx] + variable_indices = get_variable_indices(self.state, variable)[1] + value_index = list(self.variable_index.index(variable_index) for variable_index in variable_indices) + return value[value_index] + + def set_value(self, variable, value): + variable_indices = get_variable_indices(self.state, variable)[1] + value_index = list(self.variable_index.index(variable_index) for variable_index in variable_indices) + vec = np.zeros(len(self.variable_index)) + vec[value_index] = value + return vec def to_matrix_equations( diff --git a/polymatrix/expression/combinationsexpr.py b/polymatrix/expression/combinationsexpr.py new file mode 100644 index 0000000..3373365 --- /dev/null +++ b/polymatrix/expression/combinationsexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.combinationsexprmixin import CombinationsExprMixin + +class CombinationsExpr(CombinationsExprMixin): + pass diff --git a/polymatrix/expression/eyeexpr.py b/polymatrix/expression/eyeexpr.py new file mode 100644 index 0000000..4e64569 --- /dev/null +++ b/polymatrix/expression/eyeexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.eyeexprmixin import EyeExprMixin + +class EyeExpr(EyeExprMixin): + pass diff --git a/polymatrix/expression/impl/combinationsexprimpl.py b/polymatrix/expression/impl/combinationsexprimpl.py new file mode 100644 index 0000000..2e809a2 --- /dev/null +++ b/polymatrix/expression/impl/combinationsexprimpl.py @@ -0,0 +1,9 @@ +import dataclass_abc +from polymatrix.expression.combinationsexpr import CombinationsExpr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class CombinationsExprImpl(CombinationsExpr): + underlying: ExpressionBaseMixin + number: int diff --git a/polymatrix/expression/impl/eyeexprimpl.py b/polymatrix/expression/impl/eyeexprimpl.py new file mode 100644 index 0000000..9c50f8a --- /dev/null +++ b/polymatrix/expression/impl/eyeexprimpl.py @@ -0,0 +1,8 @@ +import dataclass_abc +from polymatrix.expression.eyeexpr import EyeExpr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class EyeExprImpl(EyeExpr): + variable: ExpressionBaseMixin diff --git a/polymatrix/expression/impl/getitemexprimpl.py b/polymatrix/expression/impl/getitemexprimpl.py index b8972a9..578d4c7 100644 --- a/polymatrix/expression/impl/getitemexprimpl.py +++ b/polymatrix/expression/impl/getitemexprimpl.py @@ -6,4 +6,4 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @dataclass_abc.dataclass_abc(frozen=True) class GetItemExprImpl(GetItemExpr): underlying: ExpressionBaseMixin - index: tuple + index: tuple[tuple[int, ...], tuple[int, ...]] diff --git a/polymatrix/expression/impl/linearin2exprimpl.py b/polymatrix/expression/impl/linearin2exprimpl.py new file mode 100644 index 0000000..34b42bf --- /dev/null +++ b/polymatrix/expression/impl/linearin2exprimpl.py @@ -0,0 +1,9 @@ +import dataclass_abc +from polymatrix.expression.linearin2expr import LinearIn2Expr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class LinearIn2ExprImpl(LinearIn2Expr): + underlying: ExpressionBaseMixin + expression: ExpressionBaseMixin diff --git a/polymatrix/expression/impl/squeezeexprimpl.py b/polymatrix/expression/impl/squeezeexprimpl.py new file mode 100644 index 0000000..b60af68 --- /dev/null +++ b/polymatrix/expression/impl/squeezeexprimpl.py @@ -0,0 +1,8 @@ +import dataclass_abc +from polymatrix.expression.squeezeexpr import SqueezeExpr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class SqueezeExprImpl(SqueezeExpr): + underlying: ExpressionBaseMixin diff --git a/polymatrix/expression/impl/sumexprimpl.py b/polymatrix/expression/impl/sumexprimpl.py new file mode 100644 index 0000000..b148b43 --- /dev/null +++ b/polymatrix/expression/impl/sumexprimpl.py @@ -0,0 +1,8 @@ +import dataclass_abc +from polymatrix.expression.sumexpr import SumExpr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class SumExprImpl(SumExpr): + underlying: ExpressionBaseMixin diff --git a/polymatrix/expression/impl/symmetricexprimpl.py b/polymatrix/expression/impl/symmetricexprimpl.py new file mode 100644 index 0000000..b15b8a0 --- /dev/null +++ b/polymatrix/expression/impl/symmetricexprimpl.py @@ -0,0 +1,8 @@ +import dataclass_abc +from polymatrix.expression.symmetricexpr import SymmetricExpr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class SymmetricExprImpl(SymmetricExpr): + underlying: ExpressionBaseMixin diff --git a/polymatrix/expression/impl/toconstantexprimpl.py b/polymatrix/expression/impl/toconstantexprimpl.py new file mode 100644 index 0000000..9e3d13d --- /dev/null +++ b/polymatrix/expression/impl/toconstantexprimpl.py @@ -0,0 +1,8 @@ +import dataclass_abc +from polymatrix.expression.toconstantexpr import ToConstantExpr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class ToConstantExprImpl(ToConstantExpr): + underlying: ExpressionBaseMixin diff --git a/polymatrix/expression/init/initcombinationsexpr.py b/polymatrix/expression/init/initcombinationsexpr.py new file mode 100644 index 0000000..7a1df2f --- /dev/null +++ b/polymatrix/expression/init/initcombinationsexpr.py @@ -0,0 +1,12 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.combinationsexprimpl import CombinationsExprImpl + + +def init_combinations_expr( + underlying: ExpressionBaseMixin, + number: int, +): + return CombinationsExprImpl( + underlying=underlying, + number=number, +) diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py index 525a697..49bb0a3 100644 --- a/polymatrix/expression/init/initevalexpr.py +++ b/polymatrix/expression/init/initevalexpr.py @@ -15,7 +15,7 @@ def init_eval_expr( variables, values = tuple(zip(*variables)) elif isinstance(values, np.ndarray): - values = tuple(values) + values = tuple(values.reshape(-1)) elif not isinstance(values, tuple): values = (values,) diff --git a/polymatrix/expression/init/initeyeexpr.py b/polymatrix/expression/init/initeyeexpr.py new file mode 100644 index 0000000..e691b03 --- /dev/null +++ b/polymatrix/expression/init/initeyeexpr.py @@ -0,0 +1,10 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.eyeexprimpl import EyeExprImpl + + +def init_eye_expr( + variable: ExpressionBaseMixin, +): + return EyeExprImpl( + variable=variable, +) diff --git a/polymatrix/expression/init/initgetitemexpr.py b/polymatrix/expression/init/initgetitemexpr.py index 140fa3a..5fea7a5 100644 --- a/polymatrix/expression/init/initgetitemexpr.py +++ b/polymatrix/expression/init/initgetitemexpr.py @@ -1,12 +1,22 @@ +from numpy import isin from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.impl.getitemexprimpl import GetItemExprImpl def init_get_item_expr( underlying: ExpressionBaseMixin, - index: tuple, + index: tuple[tuple[int, ...], tuple[int, ...]], ): + + def get_hashable_slice(index): + if isinstance(index, slice): + return GetItemExprImpl.Slice(start=index.start, stop=index.stop, step=index.step) + else: + return index + + proper_index = (get_hashable_slice(index[0]), get_hashable_slice(index[1])) + return GetItemExprImpl( underlying=underlying, - index=index, + index=proper_index, ) diff --git a/polymatrix/expression/init/initlinearin2expr.py b/polymatrix/expression/init/initlinearin2expr.py new file mode 100644 index 0000000..7225510 --- /dev/null +++ b/polymatrix/expression/init/initlinearin2expr.py @@ -0,0 +1,12 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.linearin2exprimpl import LinearIn2ExprImpl + + +def init_linear_in2_expr( + underlying: ExpressionBaseMixin, + expression: ExpressionBaseMixin, +): + return LinearIn2ExprImpl( + underlying=underlying, + expression=expression, +) diff --git a/polymatrix/expression/init/initpolymatrix.py b/polymatrix/expression/init/initpolymatrix.py index e6a6cde..d0c9577 100644 --- a/polymatrix/expression/init/initpolymatrix.py +++ b/polymatrix/expression/init/initpolymatrix.py @@ -4,13 +4,9 @@ from polymatrix.expression.impl.polymatriximpl import PolyMatrixImpl def init_poly_matrix( terms: dict, shape: tuple, - aux_terms: tuple[dict[tuple[int, ...], float]] = None, ): - if aux_terms is None: - aux_terms = tuple() return PolyMatrixImpl( terms=terms, shape=shape, - # aux_terms=aux_terms, ) diff --git a/polymatrix/expression/init/initsqueezeexpr.py b/polymatrix/expression/init/initsqueezeexpr.py new file mode 100644 index 0000000..d71a9b5 --- /dev/null +++ b/polymatrix/expression/init/initsqueezeexpr.py @@ -0,0 +1,10 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.squeezeexprimpl import SqueezeExprImpl + + +def init_squeeze_expr( + underlying: ExpressionBaseMixin, +): + return SqueezeExprImpl( + underlying=underlying, +) diff --git a/polymatrix/expression/init/initsumexpr.py b/polymatrix/expression/init/initsumexpr.py new file mode 100644 index 0000000..606d86b --- /dev/null +++ b/polymatrix/expression/init/initsumexpr.py @@ -0,0 +1,10 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.sumexprimpl import SumExprImpl + + +def init_sum_expr( + underlying: ExpressionBaseMixin, +): + return SumExprImpl( + underlying=underlying, +) diff --git a/polymatrix/expression/init/initsymmetricexpr.py b/polymatrix/expression/init/initsymmetricexpr.py new file mode 100644 index 0000000..87be3ed --- /dev/null +++ b/polymatrix/expression/init/initsymmetricexpr.py @@ -0,0 +1,10 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.symmetricexprimpl import SymmetricExprImpl + + +def init_symmetric_expr( + underlying: ExpressionBaseMixin, +): + return SymmetricExprImpl( + underlying=underlying, +) diff --git a/polymatrix/expression/init/inittoconstantexpr.py b/polymatrix/expression/init/inittoconstantexpr.py new file mode 100644 index 0000000..ead35f8 --- /dev/null +++ b/polymatrix/expression/init/inittoconstantexpr.py @@ -0,0 +1,10 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.toconstantexprimpl import ToConstantExprImpl + + +def init_to_constant_expr( + underlying: ExpressionBaseMixin, +): + return ToConstantExprImpl( + underlying=underlying, +) diff --git a/polymatrix/expression/linearin2expr.py b/polymatrix/expression/linearin2expr.py new file mode 100644 index 0000000..2c40dc8 --- /dev/null +++ b/polymatrix/expression/linearin2expr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin + +class LinearIn2Expr(LinearInExprMixin): + pass diff --git a/polymatrix/expression/linearinexpr.py b/polymatrix/expression/linearinexpr.py index 4edf8b3..9054937 100644 --- a/polymatrix/expression/linearinexpr.py +++ b/polymatrix/expression/linearinexpr.py @@ -1,4 +1,4 @@ -from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin +from polymatrix.expression.mixins.oldlinearexprmixin import OldLinearExprMixin -class LinearInExpr(LinearInExprMixin): +class LinearInExpr(OldLinearExprMixin): pass diff --git a/polymatrix/expression/linearmatrixinexpr.py b/polymatrix/expression/linearmatrixinexpr.py index 2bce2e7..296e902 100644 --- a/polymatrix/expression/linearmatrixinexpr.py +++ b/polymatrix/expression/linearmatrixinexpr.py @@ -1,4 +1,4 @@ -from polymatrix.expression.mixins.linearmatrixinexprmixin import LinearMatrixInExprMixin +from polymatrix.expression.mixins.filterlinearpartexprmixin import FilterLinearPartExprMixin -class LinearMatrixInExpr(LinearMatrixInExprMixin): +class LinearMatrixInExpr(FilterLinearPartExprMixin): pass diff --git a/polymatrix/expression/mixins/accumulateexprmixin.py b/polymatrix/expression/mixins/accumulateexprmixin.py index 76e1717..47df335 100644 --- a/polymatrix/expression/mixins/accumulateexprmixin.py +++ b/polymatrix/expression/mixins/accumulateexprmixin.py @@ -8,6 +8,7 @@ from polymatrix.expression.polymatrix import PolyMatrix from polymatrix.expression.expressionstate import ExpressionState +# todo: is this needed? class AccumulateExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod @@ -24,7 +25,7 @@ class AccumulateExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/addauxequationsexprmixin.py b/polymatrix/expression/mixins/addauxequationsexprmixin.py index 1dae4f7..e5353fe 100644 --- a/polymatrix/expression/mixins/addauxequationsexprmixin.py +++ b/polymatrix/expression/mixins/addauxequationsexprmixin.py @@ -10,7 +10,7 @@ from polymatrix.expression.polymatrix import PolyMatrix from polymatrix.expression.expressionstate import ExpressionState -# is this really needed? +# todo: is this really needed? class AddAuxEquationsExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod @@ -18,7 +18,7 @@ class AddAuxEquationsExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index d274f2a..2c7a652 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -5,6 +5,7 @@ import dataclass_abc from polymatrix.expression.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin from polymatrix.expression.polymatrix import PolyMatrix from polymatrix.expression.expressionstate import ExpressionState @@ -20,24 +21,49 @@ class AdditionExprMixin(ExpressionBaseMixin): def right(self) -> ExpressionBaseMixin: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return self.left.shape - # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - assert left.shape == right.shape, f'{left.shape} != {right.shape}' - terms = {} - for underlying in (left, right): + if left.shape == (1, 1): + left, right = right, left + + if left.shape != (1, 1) and right.shape == (1, 1): + + @dataclass_abc.dataclass_abc(frozen=True) + class BroadCastedPolyMatrix(PolyMatrixMixin): + underlying_monomials: tuple[tuple[int], float] + shape: tuple[int, int] + + def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: + return self.underlying_monomials + + try: + underlying_terms = right.get_poly(0, 0) + + except KeyError: + pass + + else: + broadcasted_right = BroadCastedPolyMatrix( + underlying_monomials=underlying_terms, + shape=left.shape, + ) + + all_underlying = (left, broadcasted_right) + + else: + assert left.shape == right.shape, f'{left.shape} != {right.shape}' + + all_underlying = (left, right) + + for underlying in all_underlying: for row in range(left.shape[0]): for col in range(left.shape[1]): diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py index 6bdbce1..df4bd80 100644 --- a/polymatrix/expression/mixins/blockdiagexprmixin.py +++ b/polymatrix/expression/mixins/blockdiagexprmixin.py @@ -16,7 +16,7 @@ class BlockDiagExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py index 8779ba1..ba9a612 100644 --- a/polymatrix/expression/mixins/cacheexprmixin.py +++ b/polymatrix/expression/mixins/cacheexprmixin.py @@ -15,7 +15,7 @@ class CacheExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py new file mode 100644 index 0000000..6d0f6eb --- /dev/null +++ b/polymatrix/expression/mixins/combinationsexprmixin.py @@ -0,0 +1,75 @@ + +import abc +import itertools +import math + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.expressionstate import ExpressionState + + +class CombinationsExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + @property + @abc.abstractmethod + def number(self) -> int: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + state, underlying = self.underlying.apply(state=state) + + assert underlying.shape[1] == 1 + + def gen_monomials(): + for _, monomial_terms in underlying.get_terms(): + assert len(monomial_terms) == 1, f'{monomial_terms} has more than 1 element' + + for monomial, values in monomial_terms.items(): + yield monomial, values + + monomial_terms = tuple(gen_monomials()) + + assert underlying.shape[0] == len(monomial_terms) + + combinations = tuple(itertools.combinations_with_replacement(range(underlying.shape[0]), self.number)) + + # print(combinations) + # print(math.comb(underlying.shape[0]+self.number-1, self.number)) + + terms = {} + + for row, combination in enumerate(combinations): + def gen_combination_monomials(): + for index in combination: + monomial, _ = monomial_terms[index] + yield from monomial + + combination_monomial = tuple(gen_combination_monomials()) + + def acc_combination_value(acc, index): + _, value = monomial_terms[index] + return acc * value + + *_, combination_value = itertools.accumulate( + combination, + acc_combination_value, + initial=1.0, + ) + + terms[row, 0] = {combination_monomial: combination_value} + + poly_matrix = init_poly_matrix( + terms=terms, + shape=(math.comb(underlying.shape[0]+self.number-1, self.number), 1), + ) + + return state, poly_matrix diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index a551c1f..b403873 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -30,7 +30,7 @@ class DerivativeExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py index 5436ed6..f002f0b 100644 --- a/polymatrix/expression/mixins/determinantexprmixin.py +++ b/polymatrix/expression/mixins/determinantexprmixin.py @@ -22,7 +22,7 @@ class DeterminantExprMixin(ExpressionBaseMixin): # return self.underlying.shape[0], 1 # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py index cfb6bea..dc5f616 100644 --- a/polymatrix/expression/mixins/divisionexprmixin.py +++ b/polymatrix/expression/mixins/divisionexprmixin.py @@ -25,7 +25,7 @@ class DivisionExprMixin(ExpressionBaseMixin): # return self.left.shape # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: @@ -76,7 +76,7 @@ class DivisionExprMixin(ExpressionBaseMixin): state = dataclasses.replace( state, auxillary_equations=state.auxillary_equations | {division_variable: auxillary_terms}, - cached_polymatrix=state.cache | {self: poly_matrix}, + cache=state.cache | {self: poly_matrix}, ) return state, poly_matrix diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index ca5d921..931e022 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -25,7 +25,7 @@ class ElemMultExprMixin(ExpressionBaseMixin): # return self.left.shape # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py index d4fdbea..3089bbb 100644 --- a/polymatrix/expression/mixins/evalexprmixin.py +++ b/polymatrix/expression/mixins/evalexprmixin.py @@ -26,7 +26,7 @@ class EvalExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: @@ -38,7 +38,7 @@ class EvalExprMixin(ExpressionBaseMixin): values = tuple(self.values[0] for _ in variable_indices) else: - assert len(variable_indices) == len(self.values) + assert len(variable_indices) == len(self.values), f'length of {variable_indices} does not match length of {self.values}' values = self.values diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py index 3825ed1..a4d7dc2 100644 --- a/polymatrix/expression/mixins/expressionbasemixin.py +++ b/polymatrix/expression/mixins/expressionbasemixin.py @@ -8,10 +8,10 @@ class ExpressionBaseMixin( ): @abc.abstractmethod - def _apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: ... - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: - assert isinstance(state, ExpressionStateMixin), f'{state} is not of type {ExpressionStateMixin.__name__}' + # def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + # assert isinstance(state, ExpressionStateMixin), f'{state} is not of type {ExpressionStateMixin.__name__}' - return self._apply(state) + # return self._apply(state) diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py index 4cb4a50..f8b5ef7 100644 --- a/polymatrix/expression/mixins/expressionmixin.py +++ b/polymatrix/expression/mixins/expressionmixin.py @@ -6,11 +6,13 @@ from sympy import re from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr from polymatrix.expression.init.initadditionexpr import init_addition_expr from polymatrix.expression.init.initcacheexpr import init_cache_expr +from polymatrix.expression.init.initcombinationsexpr import init_combinations_expr from polymatrix.expression.init.initderivativeexpr import init_derivative_expr from polymatrix.expression.init.initdeterminantexpr import init_determinant_expr from polymatrix.expression.init.initdivisionexpr import init_division_expr from polymatrix.expression.init.initelemmultexpr import init_elem_mult_expr from polymatrix.expression.init.initevalexpr import init_eval_expr +from polymatrix.expression.init.initlinearin2expr import init_linear_in2_expr from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr from polymatrix.expression.init.initgetitemexpr import init_get_item_expr @@ -20,6 +22,10 @@ from polymatrix.expression.init.initparametrizetermsexpr import init_parametrize from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr from polymatrix.expression.init.initrepmatexpr import init_rep_mat_expr from polymatrix.expression.init.initreshapeexpr import init_reshape_expr +from polymatrix.expression.init.initsqueezeexpr import init_squeeze_expr +from polymatrix.expression.init.initsumexpr import init_sum_expr +from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr +from polymatrix.expression.init.inittoconstantexpr import init_to_constant_expr from polymatrix.expression.init.inittoquadraticexpr import init_to_quadratic_expr from polymatrix.expression.init.inittransposeexpr import init_transpose_expr @@ -39,7 +45,7 @@ class ExpressionMixin( ... # overwrites abstract method of `PolyMatrixExprBaseMixin` - def _apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: return self.underlying.apply(state) # # overwrites abstract method of `PolyMatrixExprBaseMixin` @@ -131,11 +137,17 @@ class ExpressionMixin( return other @ self def __truediv__(self, other: ExpressionBaseMixin): + match other: + case ExpressionBaseMixin(): + right = other.underlying + case _: + right = init_from_array_expr(other) + return dataclasses.replace( self, underlying=init_division_expr( left=self.underlying, - right=other, + right=right, ), ) @@ -165,31 +177,20 @@ class ExpressionMixin( ), ) - def parametrize(self, name: str, variables: tuple) -> 'ExpressionMixin': + def combinations(self, number: int): return dataclasses.replace( self, - underlying=init_parametrize_terms_expr( - name=name, + underlying=init_combinations_expr( underlying=self.underlying, - variables=variables, + number=number, ), ) - def reshape(self, n: int, m: int) -> 'ExpressionMixin': - return dataclasses.replace( - self, - underlying=init_reshape_expr( - underlying=self.underlying, - new_shape=(n, m), - ), - ) - - def rep_mat(self, n: int, m: int) -> 'ExpressionMixin': + def determinant(self) -> 'ExpressionMixin': return dataclasses.replace( self, - underlying=init_rep_mat_expr( + underlying=init_determinant_expr( underlying=self.underlying, - repetition=(n, m), ), ) @@ -207,16 +208,21 @@ class ExpressionMixin( ), ) - def linear_in(self, variables: tuple) -> 'ExpressionMixin': + def eval( + self, + variable: tuple, + value: tuple[float, ...] = None, + ) -> 'ExpressionMixin': return dataclasses.replace( self, - underlying=init_linear_in_expr( + underlying=init_eval_expr( underlying=self.underlying, - variables=variables, + variables=variable, + values=value, ), ) - def linear_matrix_in(self, variable) -> 'ExpressionMixin': + def filter_linear_part(self, variable) -> 'ExpressionMixin': return dataclasses.replace( self, underlying=init_linear_matrix_in_expr( @@ -225,6 +231,25 @@ class ExpressionMixin( ), ) + def linear_in(self, variables: tuple) -> 'ExpressionMixin': + return dataclasses.replace( + self, + underlying=init_linear_in2_expr( + underlying=self.underlying, + expression=variables, + ), + ) + + def parametrize(self, name: str, variables: tuple) -> 'ExpressionMixin': + return dataclasses.replace( + self, + underlying=init_parametrize_terms_expr( + name=name, + underlying=self.underlying, + variables=variables, + ), + ) + def quadratic_in(self, variables: tuple) -> 'ExpressionMixin': return dataclasses.replace( self, @@ -234,25 +259,53 @@ class ExpressionMixin( ), ) - def determinant(self) -> 'ExpressionMixin': + def reshape(self, n: int, m: int) -> 'ExpressionMixin': return dataclasses.replace( self, - underlying=init_determinant_expr( + underlying=init_reshape_expr( underlying=self.underlying, + new_shape=(n, m), ), ) - def eval( - self, - variable: tuple, - value: tuple[float, ...] = None, - ) -> 'ExpressionMixin': + def rep_mat(self, n: int, m: int) -> 'ExpressionMixin': return dataclasses.replace( self, - underlying=init_eval_expr( + underlying=init_rep_mat_expr( + underlying=self.underlying, + repetition=(n, m), + ), + ) + + def squeeze(self): + return dataclasses.replace( + self, + underlying=init_squeeze_expr( + underlying=self.underlying, + ), + ) + + def sum(self): + return dataclasses.replace( + self, + underlying=init_sum_expr( + underlying=self.underlying, + ), + ) + + def symmetric(self): + return dataclasses.replace( + self, + underlying=init_symmetric_expr( + underlying=self.underlying, + ), + ) + + def to_constant(self) -> 'ExpressionMixin': + return dataclasses.replace( + self, + underlying=init_to_constant_expr( underlying=self.underlying, - variables=variable, - values=value, ), ) diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py new file mode 100644 index 0000000..fde2017 --- /dev/null +++ b/polymatrix/expression/mixins/eyeexprmixin.py @@ -0,0 +1,47 @@ + +import abc +import itertools +import dataclass_abc +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.expressionstate import ExpressionState + + +class EyeExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def variable(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + + state, variable = self.variable.apply(state) + + @dataclass_abc.dataclass_abc(frozen=True) + class EyePolyMatrix(PolyMatrixMixin): + shape: tuple[int, int] + + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: + if max(row, col) <= self.shape[0]: + if row == col: + return {tuple(): 1.0} + + else: + raise KeyError() + + else: + raise Exception(f'{(row, col)=} is out of bounds') + + value = variable.shape[0] + + polymatrix = EyePolyMatrix( + shape=(value, value), + ) + + return state, polymatrix diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py index 321d955..c8e66ef 100644 --- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py +++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py @@ -10,7 +10,8 @@ from polymatrix.expression.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import get_variable_indices -class LinearMatrixInExprMixin(ExpressionBaseMixin): +# is this class needed? +class FilterLinearPartExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod def underlying(self) -> ExpressionBaseMixin: @@ -22,15 +23,22 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - state, variable_indices = get_variable_indices(state, self.variable) + state, variables = self.variables.apply(state=state) - assert len(variable_indices) == 1 + def gen_variable_monomials(): + for _, term in variables.get_terms(): + assert len(term) == 1, f'{term} should have only a single monomial' + + for monomial in term.keys(): + yield set(monomial) + + variable_monomials = tuple(gen_variable_monomials()) terms = {} @@ -46,13 +54,22 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin): for monomial, value in underlying_terms.items(): - x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) + for variable_monomial in variable_monomials: + + remainder = list(monomial) + + try: + for variable in variable_monomial: + remainder.remove(variable) + + except ValueError: + continue - # only take linear terms - if len(x_monomial) == 1: - p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices) + # take the first that matches + if all(variable not in remainder for variable in variable_monomial): + monomial_terms[remainder] += value - monomial_terms[p_monomial] += value + break terms[row, col] = dict(monomial_terms) diff --git a/polymatrix/expression/mixins/fromarrayexprmixin.py b/polymatrix/expression/mixins/fromarrayexprmixin.py index e79810c..179cec4 100644 --- a/polymatrix/expression/mixins/fromarrayexprmixin.py +++ b/polymatrix/expression/mixins/fromarrayexprmixin.py @@ -16,7 +16,7 @@ class FromArrayExprMixin(ExpressionBaseMixin): pass # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: @@ -34,6 +34,7 @@ class FromArrayExprMixin(ExpressionBaseMixin): for symbol in poly.gens: state = state.register(key=symbol, n_param=1) + # print(f'{symbol}: {state.n_param}') terms_row_col = {} diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py index 1af3f11..4ada8ec 100644 --- a/polymatrix/expression/mixins/fromtermsexprmixin.py +++ b/polymatrix/expression/mixins/fromtermsexprmixin.py @@ -21,7 +21,7 @@ class FromTermsExprMixin(ExpressionBaseMixin): pass # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py index c8ac02e..fa7edb1 100644 --- a/polymatrix/expression/mixins/getitemexprmixin.py +++ b/polymatrix/expression/mixins/getitemexprmixin.py @@ -12,6 +12,12 @@ from polymatrix.expression.expressionstate import ExpressionState class GetItemExprMixin(ExpressionBaseMixin): + @dataclasses.dataclass(frozen=True) + class Slice: + start: int + step: int + stop: int + @property @abc.abstractmethod def underlying(self) -> ExpressionBaseMixin: @@ -19,32 +25,70 @@ class GetItemExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def index(self) -> tuple[int, int]: + def index(self) -> tuple[tuple[int, ...], tuple[int, ...]]: ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) + def get_proper_index(index, shape): + if isinstance(index, tuple): + return index + + elif isinstance(index, GetItemExprMixin.Slice): + if index.start is None: + start = 0 + else: + start = index.start + + if index.stop is None: + stop = shape + else: + stop = index.stop + + if index.step is None: + step = 1 + else: + step = index.step + + return tuple(range(start, stop, step)) + + else: + return (index,) + + proper_index = ( + get_proper_index(self.index[0], underlying.shape[0]), + get_proper_index(self.index[1], underlying.shape[1]), + ) + @dataclass_abc.dataclass_abc(frozen=True) class GetItemPolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin index: tuple[int, int] - shape: tuple[int, int] - # aux_terms: tuple + + @property + def shape(self) -> tuple[int, int]: + return (len(self.index[0]), len(self.index[1])) def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: - assert row == 0 and col == 0 + try: + n_row = self.index[0][row] + except IndexError: + raise IndexError(f'tuple index {row} out of range given {self.index[0]}') - return self.underlying.get_poly(self.index[0], self.index[1]) + try: + n_col = self.index[1][col] + except IndexError: + raise IndexError(f'tuple index {col} out of range given {self.index[1]}') + + return self.underlying.get_poly(n_row, n_col) return state, GetItemPolyMatrix( underlying=underlying, - shape=(1, 1), - index=self.index, - # aux_terms=underlying.aux_terms, + index=proper_index, )
\ No newline at end of file diff --git a/polymatrix/expression/mixins/kktexprmixin.py b/polymatrix/expression/mixins/kktexprmixin.py index 20949a1..d4a3d97 100644 --- a/polymatrix/expression/mixins/kktexprmixin.py +++ b/polymatrix/expression/mixins/kktexprmixin.py @@ -12,6 +12,7 @@ from polymatrix.expression.polymatrix import PolyMatrix from polymatrix.expression.expressionstate import ExpressionState +# todo: remove this? class KKTExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod @@ -39,7 +40,7 @@ class KKTExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index 9ea4e92..b690891 100644 --- a/polymatrix/expression/mixins/linearinexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -18,22 +18,32 @@ class LinearInExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def variables(self) -> tuple: + def expression(self) -> ExpressionBaseMixin: ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - # todo: uncomment this - state, variable_indices = get_variable_indices(state, self.variables) - # variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict) + state, variables = self.expression.apply(state=state) - terms = {} - idx_row = 0 + def gen_variable_monomials(): + for _, term in variables.get_terms(): + assert len(term) == 1, f'{term} should have only a single monomial' + + for monomial in term.keys(): + yield monomial + + variable_monomials = tuple(gen_variable_monomials()) + all_variables = set(variable for monomial in variable_monomials for variable in monomial) + + # print(variables) + # print(variable_monomials) + + terms = collections.defaultdict(dict) for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): @@ -43,25 +53,33 @@ class LinearInExprMixin(ExpressionBaseMixin): except KeyError: continue - x_monomial_terms = collections.defaultdict(lambda: collections.defaultdict(float)) - for monomial, value in underlying_terms.items(): - x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) - p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices) + for variable_index, variable_monomial in enumerate(variable_monomials): + + remainder = list(monomial) + + try: + for variable in variable_monomial: + remainder.remove(variable) + + except ValueError: + continue - assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted' + # print(remainder) + # print(variable_monomial) + # print(all(variable not in remainder for variable in all_variables)) - x_monomial_terms[x_monomial][p_monomial] += value + # take the first that matches + if all(variable not in remainder for variable in all_variables): - for data in x_monomial_terms.values(): - terms[idx_row, 0] = dict(data) - idx_row += 1 + terms[row, variable_index][tuple(remainder)] = value + break poly_matrix = init_poly_matrix( - terms=terms, - shape=(idx_row, 1), + terms=dict(terms), + shape=(underlying.shape[0], len(variable_monomials)), ) return state, poly_matrix diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py index f1e1720..1f09fdc 100644 --- a/polymatrix/expression/mixins/matrixmultexprmixin.py +++ b/polymatrix/expression/mixins/matrixmultexprmixin.py @@ -20,14 +20,14 @@ class MatrixMultExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - assert left.shape[1] == right.shape[0] + assert left.shape[1] == right.shape[0], f'{left.shape[1]} is not equal to {right.shape[0]}' terms = {} diff --git a/polymatrix/expression/mixins/oldlinearexprmixin.py b/polymatrix/expression/mixins/oldlinearexprmixin.py new file mode 100644 index 0000000..1bdebd9 --- /dev/null +++ b/polymatrix/expression/mixins/oldlinearexprmixin.py @@ -0,0 +1,131 @@ + +import abc +import collections +from numpy import var + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.expressionstate import ExpressionState +from polymatrix.expression.utils.getvariableindices import get_variable_indices + + +class OldLinearExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + @property + @abc.abstractmethod + def variables(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + state, underlying = self.underlying.apply(state=state) + + state, variable_indices = get_variable_indices(state, self.variables) + + underlying_terms = underlying.get_terms() + + def gen_variable_terms(): + for _, monomial_term in underlying_terms: + for monomial in monomial_term.keys(): + + x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) + yield x_monomial + + variable_terms = tuple(set(gen_variable_terms())) + + terms = {} + + for (row, _), monomial_term in underlying_terms: + + x_monomial_terms = collections.defaultdict(lambda: collections.defaultdict(float)) + + for monomial, value in monomial_term.items(): + + x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) + p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices) + + assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted' + + x_monomial_terms[x_monomial][p_monomial] += value + + for x_monomial, data in x_monomial_terms.items(): + terms[row, variable_terms.index(x_monomial)] = dict(data) + + poly_matrix = init_poly_matrix( + terms=terms, + shape=(underlying.shape[0], len(variable_terms)), + ) + + return state, poly_matrix + + + # for row in range(underlying.shape[0]): + # for col in range(underlying.shape[1]): + + # try: + # underlying_terms = underlying.get_poly(row, col) + # except KeyError: + # continue + + # x_monomial_terms = collections.defaultdict(lambda: collections.defaultdict(float)) + + # for monomial, value in underlying_terms.items(): + + # x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) + # p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices) + + # assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted' + + # x_monomial_terms[x_monomial][p_monomial] += value + + # for data in x_monomial_terms.values(): + # terms[idx_row, 0] = dict(data) + # idx_row += 1 + + # poly_matrix = init_poly_matrix( + # terms=terms, + # shape=(idx_row, 1), + # ) + + # return state, poly_matrix + + # terms = {} + # idx_row = 0 + + # for row in range(underlying.shape[0]): + # for col in range(underlying.shape[1]): + + # try: + # underlying_terms = underlying.get_poly(row, col) + # except KeyError: + # continue + + # x_monomial_terms = collections.defaultdict(lambda: collections.defaultdict(float)) + + # for monomial, value in underlying_terms.items(): + + # x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices) + # p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices) + + # assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted' + + # x_monomial_terms[x_monomial][p_monomial] += value + + # for data in x_monomial_terms.values(): + # terms[idx_row, 0] = dict(data) + # idx_row += 1 + + # poly_matrix = init_poly_matrix( + # terms=terms, + # shape=(idx_row, 1), + # ) + + # return state, poly_matrix diff --git a/polymatrix/expression/mixins/parametrizetermsexprmixin.py b/polymatrix/expression/mixins/parametrizetermsexprmixin.py index 255a5da..d678db5 100644 --- a/polymatrix/expression/mixins/parametrizetermsexprmixin.py +++ b/polymatrix/expression/mixins/parametrizetermsexprmixin.py @@ -1,15 +1,12 @@ import abc import dataclasses -import functools -import dataclass_abc -from polymatrix.expression.init.initexpressionstate import init_expression_state from polymatrix.expression.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin -from polymatrix.expression.mixins.polymatrixasdictmixin import PolyMatrixAsDictMixin from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin +from polymatrix.expression.utils.getvariableindices import get_variable_indices class ParametrizeTermsExprMixin(ExpressionBaseMixin): @@ -29,7 +26,7 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: @@ -37,13 +34,14 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin): if self in state.cache: return state, state.cache[self] - # if not hasattr(self, '_terms'): state, underlying = self.underlying.apply(state) - for variable in self.variables: - state = state.register(key=variable, n_param=1) + # for variable in self.variables: + # state = state.register(key=variable, n_param=1) - variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict) + # variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict) + + state, variable_indices = get_variable_indices(state, self.variables) start_index = state.n_param terms = {} diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py deleted file mode 100644 index 1fff79b..0000000 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ /dev/null @@ -1,68 +0,0 @@ - -import abc - -from polymatrix.expression.init.initpolymatrix import init_poly_matrix -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expression.polymatrix import PolyMatrix -from polymatrix.expression.expressionstate import ExpressionState -from polymatrix.expression.utils.getvariableindices import get_variable_indices - - -class QuadraticInExprMixin(ExpressionBaseMixin): - @property - @abc.abstractmethod - def underlying(self) -> ExpressionBaseMixin: - ... - - @property - @abc.abstractmethod - def variables(self) -> tuple: - ... - - # overwrites abstract method of `ExpressionBaseMixin` - def _apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: - state, underlying = self.underlying.apply(state=state) - - assert underlying.shape == (1, 1) - - state, variable_indices = get_variable_indices(state, self.variables) - - terms = {} - - for row in range(underlying.shape[0]): - for col in range(underlying.shape[1]): - - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: - continue - - for monomial, value in underlying_terms.items(): - - x_monomial = tuple(variable_indices.index(var_idx) for var_idx in monomial if var_idx in variable_indices) - p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices) - - assert len(x_monomial) == 2, f'{x_monomial} should be of length 2' - assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted' - - key = tuple(reversed(x_monomial)) - - if key not in terms: - terms[key] = {} - - monomial_terms = terms[key] - - if p_monomial not in monomial_terms: - monomial_terms[p_monomial] = 0 - - monomial_terms[p_monomial] += value - - poly_matrix = init_poly_matrix( - terms=terms, - shape=2*(len(self.variables),), - ) - - return state, poly_matrix diff --git a/polymatrix/expression/mixins/quadraticreprexprmixin.py b/polymatrix/expression/mixins/quadraticreprexprmixin.py new file mode 100644 index 0000000..29ef950 --- /dev/null +++ b/polymatrix/expression/mixins/quadraticreprexprmixin.py @@ -0,0 +1,87 @@ + +import abc +import collections + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.expressionstate import ExpressionState +from polymatrix.expression.utils.getvariableindices import get_variable_indices + + +class QuadraticReprExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + @property + @abc.abstractmethod + def variables(self) -> tuple: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + state, underlying = self.underlying.apply(state=state) + + assert underlying.shape == (1, 1) + + state, variable_indices = get_variable_indices(state, self.variables) + + # state, variables = self.expression.apply(state=state) + + def gen_variable_monomials(): + _, monomials = underlying.get_terms()[0] + + for monomial, value in monomials.items(): + x_monomial = tuple(variable for variable in monomial if variable in variable_indices) + p_monomial = tuple(variable for variable in monomial if variable not in variable_indices) + + assert len(x_monomial) % 2 == 0, f'length of {x_monomial} should be a multiple of 2' + + x_monomial_sorted = tuple(sorted(x_monomial)) + + idx_half = int(len(x_monomial)/2) + + yield x_monomial_sorted[:idx_half], x_monomial_sorted[idx_half:], p_monomial, value + + underlying_monomials = tuple(gen_variable_monomials()) + + def gen_sos_monomials(): + for left_monomial, right_monomial, _, _ in underlying_monomials: + yield left_monomial + yield right_monomial + + sos_monomials = tuple(sorted(set(gen_sos_monomials()), key=lambda m: (len(m), m))) + + # print(f'{underlying_monomials=}') + # print(f'{sos_monomials=}') + + terms = collections.defaultdict(dict) + + for left_monomial, right_monomial, p_monomial, value in underlying_monomials: + + col = sos_monomials.index(left_monomial) + row = sos_monomials.index(right_monomial) + key = (row, col) + + # print(key) + + monomial_terms = terms[key] + + if p_monomial not in monomial_terms: + monomial_terms[p_monomial] = 0 + + monomial_terms[p_monomial] += value + + # print(terms) + + poly_matrix = init_poly_matrix( + terms=dict(terms), + shape=2*(len(sos_monomials),), + ) + + return state, poly_matrix diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py index 5d5c00f..90c6a66 100644 --- a/polymatrix/expression/mixins/repmatexprmixin.py +++ b/polymatrix/expression/mixins/repmatexprmixin.py @@ -17,7 +17,7 @@ class RepMatExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py index 886f074..ad2a5a6 100644 --- a/polymatrix/expression/mixins/reshapeexprmixin.py +++ b/polymatrix/expression/mixins/reshapeexprmixin.py @@ -20,7 +20,7 @@ class ReshapeExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py new file mode 100644 index 0000000..47f2fcf --- /dev/null +++ b/polymatrix/expression/mixins/squeezeexprmixin.py @@ -0,0 +1,53 @@ + +import abc +import collections +import typing +import dataclass_abc + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.expressionstate import ExpressionState + + +class SqueezeExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + state, underlying = self.underlying.apply(state=state) + + assert underlying.shape[1] == 1 + + terms = {} + row_index = 0 + + for row in range(underlying.shape[0]): + + try: + underlying_terms = underlying.get_poly(row, 0) + except KeyError: + continue + + terms_row_col = {} + + for monomial, value in underlying_terms.items(): + if value != 0.0: + terms_row_col[monomial] = value + + if len(terms_row_col): + terms[row_index, 0] = terms_row_col + row_index += 1 + + poly_matrix = init_poly_matrix( + terms=terms, + shape=(row_index, 1), + ) + + return state, poly_matrix diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py new file mode 100644 index 0000000..96bf3da --- /dev/null +++ b/polymatrix/expression/mixins/sumexprmixin.py @@ -0,0 +1,49 @@ + +import abc +import collections +import dataclasses + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin + + +class SumExprMixin(ExpressionBaseMixin): + @property + @abc.abstractclassmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionStateMixin, + ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + + state, underlying = self.underlying.apply(state) + + terms = collections.defaultdict(dict) + + for row in range(underlying.shape[0]): + for col in range(underlying.shape[1]): + + try: + underlying_terms = underlying.get_poly(row, col) + except KeyError: + continue + + term_monomials = terms[row, 0] + + for monomial, coeff in underlying_terms.items(): + if monomial in term_monomials: + term_monomials[monomial] += coeff + else: + term_monomials[monomial] = coeff + + poly_matrix = init_poly_matrix( + terms=terms, + shape=(underlying.shape[0], 1), + ) + + return state, poly_matrix diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py new file mode 100644 index 0000000..e914895 --- /dev/null +++ b/polymatrix/expression/mixins/symmetricexprmixin.py @@ -0,0 +1,68 @@ + +import abc +import itertools +import dataclass_abc +from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.expressionstate import ExpressionState + + +class SymmetricExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + + state, underlying = self.underlying.apply(state=state) + + assert underlying.shape[0] == underlying.shape[1] + + @dataclass_abc.dataclass_abc(frozen=True) + class SymmetricPolyMatrix(PolyMatrixMixin): + underlying: PolyMatrixMixin + + @property + def shape(self) -> tuple[int, int]: + return self.underlying.shape + + def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: + + def gen_symmetric_monomials(): + for i_row, i_col in ((row, col), (col, row)): + try: + monomials = self.underlying.get_poly(i_row, i_col) + except: + pass + else: + yield monomials + + all_monomials = tuple(gen_symmetric_monomials()) + + if len(all_monomials) == 0: + raise KeyError() + + else: + terms = {} + + # merge monomials + for monomials in all_monomials: + for monomial, value in monomials.items(): + if monomial in terms: + terms[monomial] = (terms[monomial] + value) / 2 + else: + terms[monomial] = value + return terms + + polymat = SymmetricPolyMatrix( + underlying=underlying, + ) + + return state, polymat diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py new file mode 100644 index 0000000..1276251 --- /dev/null +++ b/polymatrix/expression/mixins/toconstantexprmixin.py @@ -0,0 +1,44 @@ + +import abc +import collections +from numpy import var + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.expression.expressionstate import ExpressionState +from polymatrix.expression.utils.getvariableindices import get_variable_indices + + +class ToConstantExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: + state, underlying = self.underlying.apply(state=state) + + terms = collections.defaultdict(dict) + + for row in range(underlying.shape[0]): + for col in range(underlying.shape[1]): + + try: + underlying_terms = underlying.get_poly(row, col) + except KeyError: + continue + + if tuple() in underlying_terms: + terms[row, col][tuple()] = underlying_terms[tuple()] + + poly_matrix = init_poly_matrix( + terms=dict(terms), + shape=underlying.shape, + ) + + return state, poly_matrix diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py index 0aac30a..2b62a89 100644 --- a/polymatrix/expression/mixins/toquadraticexprmixin.py +++ b/polymatrix/expression/mixins/toquadraticexprmixin.py @@ -16,7 +16,7 @@ class ToQuadraticExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py index fa074d7..da1c288 100644 --- a/polymatrix/expression/mixins/transposeexprmixin.py +++ b/polymatrix/expression/mixins/transposeexprmixin.py @@ -18,7 +18,7 @@ class TransposeExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `PolyMatrixExprBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py index 90cfb20..b27828e 100644 --- a/polymatrix/expression/mixins/vstackexprmixin.py +++ b/polymatrix/expression/mixins/vstackexprmixin.py @@ -16,7 +16,7 @@ class VStackExprMixin(ExpressionBaseMixin): ... # overwrites abstract method of `ExpressionBaseMixin` - def _apply( + def apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: @@ -26,7 +26,8 @@ class VStackExprMixin(ExpressionBaseMixin): state, polymat = expr.apply(state=state) all_underlying.append(polymat) - assert all(underlying.shape[1] == all_underlying[0].shape[1] for underlying in all_underlying) + for underlying in all_underlying: + assert underlying.shape[1] == all_underlying[0].shape[1], f'{underlying.shape[1]} not equal {all_underlying[0].shape[1]}' @dataclass_abc.dataclass_abc(frozen=True) class VStackPolyMatrix(PolyMatrixMixin): diff --git a/polymatrix/expression/quadraticinexpr.py b/polymatrix/expression/quadraticinexpr.py index bab76f6..59f452b 100644 --- a/polymatrix/expression/quadraticinexpr.py +++ b/polymatrix/expression/quadraticinexpr.py @@ -1,4 +1,4 @@ -from polymatrix.expression.mixins.quadraticinexprmixin import QuadraticInExprMixin +from polymatrix.expression.mixins.quadraticreprexprmixin import QuadraticReprExprMixin -class QuadraticInExpr(QuadraticInExprMixin): +class QuadraticInExpr(QuadraticReprExprMixin): pass diff --git a/polymatrix/expression/squeezeexpr.py b/polymatrix/expression/squeezeexpr.py new file mode 100644 index 0000000..5472764 --- /dev/null +++ b/polymatrix/expression/squeezeexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin + +class SqueezeExpr(SqueezeExprMixin): + pass diff --git a/polymatrix/expression/sumexpr.py b/polymatrix/expression/sumexpr.py new file mode 100644 index 0000000..7b62e59 --- /dev/null +++ b/polymatrix/expression/sumexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.sumexprmixin import SumExprMixin + +class SumExpr(SumExprMixin): + pass diff --git a/polymatrix/expression/symmetricexpr.py b/polymatrix/expression/symmetricexpr.py new file mode 100644 index 0000000..ffbaa90 --- /dev/null +++ b/polymatrix/expression/symmetricexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.symmetricexprmixin import SymmetricExprMixin + +class SymmetricExpr(SymmetricExprMixin): + pass diff --git a/polymatrix/expression/toconstantexpr.py b/polymatrix/expression/toconstantexpr.py new file mode 100644 index 0000000..d1fb2a9 --- /dev/null +++ b/polymatrix/expression/toconstantexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.toconstantexprmixin import ToConstantExprMixin + +class ToConstantExpr(ToConstantExprMixin): + pass diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index 46bb51b..71a91f6 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -14,16 +14,18 @@ def get_variable_indices(state, variables): for row in range(variables.shape[0]): row_terms = variables.get_poly(row, 0) - assert len(row_terms) == 1 + assert len(row_terms) == 1, f'{row_terms} contains more than one term' for monomial in row_terms.keys(): - assert len(monomial) == 1 + assert len(monomial) == 1, f'{monomial} contains more than one variable' yield monomial[0] return state, tuple(gen_indices()) else: + raise Exception('not supported anymore') + if not isinstance(variables, tuple): variables = (variables,) diff --git a/polymatrix/sympyutils.py b/polymatrix/sympyutils.py index ac91b8f..3e727ed 100644 --- a/polymatrix/sympyutils.py +++ b/polymatrix/sympyutils.py @@ -18,7 +18,7 @@ def poly_to_data_coord(poly_list, x, degree = None): sympy_poly_list = tuple(tuple(sympy.poly(p, x) for p in inner_poly_list) for inner_poly_list in poly_list) if degree is None: - degree = max(degree for inner_poly_list in sympy_poly_list for poly in inner_poly_list for degree in poly.degree_list()) + degree = max(sum(monom) for inner_poly_list in sympy_poly_list for poly in inner_poly_list for monom in poly.monoms()) n = len(x) |