diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-04-04 14:47:50 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-04-04 14:47:50 +0200 |
commit | 561997eaa73ad81ec2da63e6e62cdf776d87be4e (patch) | |
tree | 6e37072a76772e70738f3572b63b5836688eeeb3 /polymatrix/expression/mixins/derivativeexprmixin.py | |
parent | bugfixes in KKT conditions (diff) | |
download | polymatrix-561997eaa73ad81ec2da63e6e62cdf776d87be4e.tar.gz polymatrix-561997eaa73ad81ec2da63e6e62cdf776d87be4e.zip |
implement polynomial matrix as an expression
Diffstat (limited to 'polymatrix/expression/mixins/derivativeexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/derivativeexprmixin.py | 155 |
1 files changed, 155 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py new file mode 100644 index 0000000..516aa38 --- /dev/null +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -0,0 +1,155 @@ + +import abc +import collections +import typing +import dataclass_abc +from numpy import var +from polymatrix.expression.init.initderivativekey import init_derivative_key + +from polymatrix.expression.init.initpolymatrix import init_poly_matrix +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.polymatrix import PolyMatrix +from polymatrix.polymatrixexprstate import PolyMatrixExprState + + +class DerivativeExprMixin(ExpressionBaseMixin): + @property + @abc.abstractmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + @property + @abc.abstractmethod + def variables(self) -> typing.Union[tuple, ExpressionBaseMixin]: + ... + + @property + @abc.abstractmethod + def introduce_derivatives(self) -> bool: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + @property + def shape(self) -> tuple[int, int]: + match self.variables: + case ExpressionBaseMixin(): + n_cols = self.variables.shape[0] + + case _: + n_cols = len(self.variables) + + return self.underlying.shape[0], n_cols + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: PolyMatrixExprState, + ) -> tuple[PolyMatrixExprState, PolyMatrix]: + state = [state] + + state[0], underlying = self.underlying.apply(state=state[0]) + + match self.variables: + case ExpressionBaseMixin(): + assert self.variables.shape[1] == 1 + + state[0], variables = self.variables.apply(state[0]) + + def gen_indices(): + for row in range(variables.shape[0]): + for monomial in variables.get_poly(row, 0).keys(): + yield monomial[0] + + variable_indices = tuple(sorted(gen_indices())) + # print(f'{variable_indices=}') + + case _: + def gen_indices(): + for variable in self.variable: + if variable in state[0].offset_dict: + yield state[0].offset_dict[variable][0] + + variable_indices = tuple(sorted(gen_indices())) + + terms = {} + # aux_terms = [] + + for var_idx, variable in enumerate(variable_indices): + + def get_derivative_terms(monomial_terms): + + terms_row_col = {} + + for monomial, value in monomial_terms.items(): + + # count powers for each variable + monomial_cnt = dict(collections.Counter(monomial)) + + if variable not in monomial_cnt: + continue + + if self.introduce_derivatives: + variable_candidates = (variable,) + tuple(var for var in monomial_cnt.keys() if var not in variable_indices) + else: + variable_candidates = (variable,) + + for variable_candidate in variable_candidates: + + def generate_monomial(): + for current_variable, current_count in monomial_cnt.items(): + + if current_variable is variable_candidate: + sel_counter = current_count - 1 + + else: + sel_counter = current_count + + for _ in range(sel_counter): + yield current_variable + + if variable_candidate is not variable: + key = init_derivative_key( + variable=variable_candidate, + with_respect_to=var_idx, + ) + state[0] = state[0].register(key=key, n_param=1) + + yield state[0].offset_dict[key][0] + + col_monomial = tuple(generate_monomial()) + + if col_monomial not in terms_row_col: + terms_row_col[col_monomial] = 0 + + terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate] + + return terms_row_col + + for row in range(self.shape[0]): + + try: + underlying_terms = underlying.get_poly(row, 0) + except KeyError: + continue + + derivative_terms = get_derivative_terms(underlying_terms) + + if 0 < len(derivative_terms): + terms[row, var_idx] = derivative_terms + + # if self.introduce_derivatives: + # for aux_monomial in underlying.aux_terms: + + # derivative_terms = get_derivative_terms(aux_monomial) + + # # todo: is this correct? + # if 1 < len(derivative_terms): + # aux_terms.append(derivative_terms) + + poly_matrix = init_poly_matrix( + terms=terms, + shape=self.shape, + # aux_terms=underlying.aux_terms + tuple(aux_terms), + ) + + return state[0], poly_matrix |