diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-04-04 14:47:50 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-04-04 14:47:50 +0200 |
commit | 561997eaa73ad81ec2da63e6e62cdf776d87be4e (patch) | |
tree | 6e37072a76772e70738f3572b63b5836688eeeb3 /polymatrix/expression/mixins/determinantexprmixin.py | |
parent | bugfixes in KKT conditions (diff) | |
download | polymatrix-561997eaa73ad81ec2da63e6e62cdf776d87be4e.tar.gz polymatrix-561997eaa73ad81ec2da63e6e62cdf776d87be4e.zip |
implement polynomial matrix as an expression
Diffstat (limited to 'polymatrix/expression/mixins/determinantexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/determinantexprmixin.py | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py new file mode 100644 index 0000000..8fe2536 --- /dev/null +++ b/polymatrix/expression/mixins/determinantexprmixin.py @@ -0,0 +1,115 @@ + +import abc +import collections +import dataclasses +from numpy import var + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.polymatrixexprstate import PolyMatrixExprState + + +class DeterminantExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + @property + def shape(self) -> tuple[int, int]: + return self.underlying.shape[0], 1 + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: PolyMatrixExprState, + ) -> tuple[PolyMatrixExprState, PolyMatrix]: + if self in state.cached_polymatrix: + return state, state.cached_polymatrix[self] + + state, underlying = self.underlying.apply(state=state) + + assert underlying.shape[0] == underlying.shape[1] + + inequality_terms = {} + auxillary_terms = [] + + index_start = state.n_param + rel_index = 0 + + for row in range(self.shape[0]): + + current_inequality_terms = collections.defaultdict(float) + + # f in f-v^T@x-r^2 + # terms = underlying.get_poly(row, row) + try: + underlying_terms = underlying.get_poly(row, row) + except KeyError: + pass + else: + for monomial, value in underlying_terms.items(): + current_inequality_terms[monomial] += value + + for inner_row in range(row): + + # -v^T@x in f-v^T@x-r^2 + # terms = underlying.get_poly(row, inner_row) + try: + underlying_terms = underlying.get_poly(row, inner_row) + except KeyError: + pass + else: + for monomial, value in underlying_terms.items(): + new_monomial = monomial + (index_start + rel_index + inner_row,) + current_inequality_terms[new_monomial] -= value + + auxillary_term = collections.defaultdict(float) + + for inner_col in range(row): + + # P@x in P@x-v + key = tuple(reversed(sorted((inner_row, inner_col)))) + # terms = underlying.get_poly(*key) + try: + underlying_terms = underlying.get_poly(*key) + except KeyError: + pass + else: + for monomial, value in underlying_terms.items(): + new_monomial = monomial + (index_start + rel_index + inner_col,) + auxillary_term[new_monomial] += value + + # -v in P@x-v + # terms = underlying.get_poly(row, inner_row) + try: + underlying_terms = underlying.get_poly(row, inner_row) + except KeyError: + pass + else: + for monomial, value in underlying_terms.items(): + auxillary_term[monomial] -= value + + auxillary_terms.append(dict(auxillary_term)) + + rel_index += row + inequality_terms[row, 0] = dict(current_inequality_terms) + + state = state.register(rel_index) + + # print(f'{auxillary_terms=}') + + poly_matrix = init_poly_matrix( + terms=inequality_terms, + shape=self.shape, + ) + + state = dataclasses.replace( + state, + auxillary_terms=state.auxillary_terms + tuple(auxillary_terms), + cached_polymatrix=state.cached_polymatrix | {self: poly_matrix}, + ) + + return state, poly_matrix |