summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/derivativeexprmixin.py
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-04-04 14:47:50 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-04-04 14:47:50 +0200
commit561997eaa73ad81ec2da63e6e62cdf776d87be4e (patch)
tree6e37072a76772e70738f3572b63b5836688eeeb3 /polymatrix/expression/mixins/derivativeexprmixin.py
parentbugfixes in KKT conditions (diff)
downloadpolymatrix-561997eaa73ad81ec2da63e6e62cdf776d87be4e.tar.gz
polymatrix-561997eaa73ad81ec2da63e6e62cdf776d87be4e.zip
implement polynomial matrix as an expression
Diffstat (limited to 'polymatrix/expression/mixins/derivativeexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py155
1 files changed, 155 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
new file mode 100644
index 0000000..516aa38
--- /dev/null
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -0,0 +1,155 @@
+
+import abc
+import collections
+import typing
+import dataclass_abc
+from numpy import var
+from polymatrix.expression.init.initderivativekey import init_derivative_key
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class DerivativeExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def variables(self) -> typing.Union[tuple, ExpressionBaseMixin]:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def introduce_derivatives(self) -> bool:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ match self.variables:
+ case ExpressionBaseMixin():
+ n_cols = self.variables.shape[0]
+
+ case _:
+ n_cols = len(self.variables)
+
+ return self.underlying.shape[0], n_cols
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state = [state]
+
+ state[0], underlying = self.underlying.apply(state=state[0])
+
+ match self.variables:
+ case ExpressionBaseMixin():
+ assert self.variables.shape[1] == 1
+
+ state[0], variables = self.variables.apply(state[0])
+
+ def gen_indices():
+ for row in range(variables.shape[0]):
+ for monomial in variables.get_poly(row, 0).keys():
+ yield monomial[0]
+
+ variable_indices = tuple(sorted(gen_indices()))
+ # print(f'{variable_indices=}')
+
+ case _:
+ def gen_indices():
+ for variable in self.variable:
+ if variable in state[0].offset_dict:
+ yield state[0].offset_dict[variable][0]
+
+ variable_indices = tuple(sorted(gen_indices()))
+
+ terms = {}
+ # aux_terms = []
+
+ for var_idx, variable in enumerate(variable_indices):
+
+ def get_derivative_terms(monomial_terms):
+
+ terms_row_col = {}
+
+ for monomial, value in monomial_terms.items():
+
+ # count powers for each variable
+ monomial_cnt = dict(collections.Counter(monomial))
+
+ if variable not in monomial_cnt:
+ continue
+
+ if self.introduce_derivatives:
+ variable_candidates = (variable,) + tuple(var for var in monomial_cnt.keys() if var not in variable_indices)
+ else:
+ variable_candidates = (variable,)
+
+ for variable_candidate in variable_candidates:
+
+ def generate_monomial():
+ for current_variable, current_count in monomial_cnt.items():
+
+ if current_variable is variable_candidate:
+ sel_counter = current_count - 1
+
+ else:
+ sel_counter = current_count
+
+ for _ in range(sel_counter):
+ yield current_variable
+
+ if variable_candidate is not variable:
+ key = init_derivative_key(
+ variable=variable_candidate,
+ with_respect_to=var_idx,
+ )
+ state[0] = state[0].register(key=key, n_param=1)
+
+ yield state[0].offset_dict[key][0]
+
+ col_monomial = tuple(generate_monomial())
+
+ if col_monomial not in terms_row_col:
+ terms_row_col[col_monomial] = 0
+
+ terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate]
+
+ return terms_row_col
+
+ for row in range(self.shape[0]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, 0)
+ except KeyError:
+ continue
+
+ derivative_terms = get_derivative_terms(underlying_terms)
+
+ if 0 < len(derivative_terms):
+ terms[row, var_idx] = derivative_terms
+
+ # if self.introduce_derivatives:
+ # for aux_monomial in underlying.aux_terms:
+
+ # derivative_terms = get_derivative_terms(aux_monomial)
+
+ # # todo: is this correct?
+ # if 1 < len(derivative_terms):
+ # aux_terms.append(derivative_terms)
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ # aux_terms=underlying.aux_terms + tuple(aux_terms),
+ )
+
+ return state[0], poly_matrix