summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/determinantexprmixin.py
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-04-04 14:47:50 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-04-04 14:47:50 +0200
commit561997eaa73ad81ec2da63e6e62cdf776d87be4e (patch)
tree6e37072a76772e70738f3572b63b5836688eeeb3 /polymatrix/expression/mixins/determinantexprmixin.py
parentbugfixes in KKT conditions (diff)
downloadpolymatrix-561997eaa73ad81ec2da63e6e62cdf776d87be4e.tar.gz
polymatrix-561997eaa73ad81ec2da63e6e62cdf776d87be4e.zip
implement polynomial matrix as an expression
Diffstat (limited to 'polymatrix/expression/mixins/determinantexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/determinantexprmixin.py115
1 files changed, 115 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py
new file mode 100644
index 0000000..8fe2536
--- /dev/null
+++ b/polymatrix/expression/mixins/determinantexprmixin.py
@@ -0,0 +1,115 @@
+
+import abc
+import collections
+import dataclasses
+from numpy import var
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class DeterminantExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.underlying.shape[0], 1
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ if self in state.cached_polymatrix:
+ return state, state.cached_polymatrix[self]
+
+ state, underlying = self.underlying.apply(state=state)
+
+ assert underlying.shape[0] == underlying.shape[1]
+
+ inequality_terms = {}
+ auxillary_terms = []
+
+ index_start = state.n_param
+ rel_index = 0
+
+ for row in range(self.shape[0]):
+
+ current_inequality_terms = collections.defaultdict(float)
+
+ # f in f-v^T@x-r^2
+ # terms = underlying.get_poly(row, row)
+ try:
+ underlying_terms = underlying.get_poly(row, row)
+ except KeyError:
+ pass
+ else:
+ for monomial, value in underlying_terms.items():
+ current_inequality_terms[monomial] += value
+
+ for inner_row in range(row):
+
+ # -v^T@x in f-v^T@x-r^2
+ # terms = underlying.get_poly(row, inner_row)
+ try:
+ underlying_terms = underlying.get_poly(row, inner_row)
+ except KeyError:
+ pass
+ else:
+ for monomial, value in underlying_terms.items():
+ new_monomial = monomial + (index_start + rel_index + inner_row,)
+ current_inequality_terms[new_monomial] -= value
+
+ auxillary_term = collections.defaultdict(float)
+
+ for inner_col in range(row):
+
+ # P@x in P@x-v
+ key = tuple(reversed(sorted((inner_row, inner_col))))
+ # terms = underlying.get_poly(*key)
+ try:
+ underlying_terms = underlying.get_poly(*key)
+ except KeyError:
+ pass
+ else:
+ for monomial, value in underlying_terms.items():
+ new_monomial = monomial + (index_start + rel_index + inner_col,)
+ auxillary_term[new_monomial] += value
+
+ # -v in P@x-v
+ # terms = underlying.get_poly(row, inner_row)
+ try:
+ underlying_terms = underlying.get_poly(row, inner_row)
+ except KeyError:
+ pass
+ else:
+ for monomial, value in underlying_terms.items():
+ auxillary_term[monomial] -= value
+
+ auxillary_terms.append(dict(auxillary_term))
+
+ rel_index += row
+ inequality_terms[row, 0] = dict(current_inequality_terms)
+
+ state = state.register(rel_index)
+
+ # print(f'{auxillary_terms=}')
+
+ poly_matrix = init_poly_matrix(
+ terms=inequality_terms,
+ shape=self.shape,
+ )
+
+ state = dataclasses.replace(
+ state,
+ auxillary_terms=state.auxillary_terms + tuple(auxillary_terms),
+ cached_polymatrix=state.cached_polymatrix | {self: poly_matrix},
+ )
+
+ return state, poly_matrix