summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/derivativeexprmixin.py
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-04-08 08:51:19 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-04-08 08:51:19 +0200
commit19ab60ad70e7d8478609a47a2e937fef3720dac9 (patch)
tree8836fbae02718948a45a6bd2e4c3d1f55548c7c5 /polymatrix/expression/mixins/derivativeexprmixin.py
parentimplement polynomial matrix as an expression (diff)
downloadpolymatrix-19ab60ad70e7d8478609a47a2e937fef3720dac9.tar.gz
polymatrix-19ab60ad70e7d8478609a47a2e937fef3720dac9.zip
remove shape property, introduce accumulate function, reimplement derivative expression
Diffstat (limited to 'polymatrix/expression/mixins/derivativeexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py503
1 files changed, 427 insertions, 76 deletions
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index 516aa38..6098aa9 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -1,9 +1,11 @@
import abc
import collections
+import dataclasses
+import itertools
import typing
import dataclass_abc
-from numpy import var
+from numpy import nonzero, var
from polymatrix.expression.init.initderivativekey import init_derivative_key
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
@@ -28,128 +30,477 @@ class DerivativeExprMixin(ExpressionBaseMixin):
def introduce_derivatives(self) -> bool:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- match self.variables:
- case ExpressionBaseMixin():
- n_cols = self.variables.shape[0]
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # match self.variables:
+ # case ExpressionBaseMixin():
+ # n_cols = self.variables.shape[0]
- case _:
- n_cols = len(self.variables)
+ # case _:
+ # n_cols = len(self.variables)
- return self.underlying.shape[0], n_cols
+ # return self.underlying.shape[0], n_cols
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
self,
state: PolyMatrixExprState,
) -> tuple[PolyMatrixExprState, PolyMatrix]:
- state = [state]
-
- state[0], underlying = self.underlying.apply(state=state[0])
match self.variables:
case ExpressionBaseMixin():
- assert self.variables.shape[1] == 1
+ state, variables = self.variables.apply(state)
- state[0], variables = self.variables.apply(state[0])
+ assert variables.shape[1] == 1
def gen_indices():
for row in range(variables.shape[0]):
- for monomial in variables.get_poly(row, 0).keys():
+ row_terms = variables.get_poly(row, 0)
+
+ assert len(row_terms) == 1
+
+ for monomial in row_terms.keys():
+ assert len(monomial) == 1
yield monomial[0]
- variable_indices = tuple(sorted(gen_indices()))
- # print(f'{variable_indices=}')
+ diff_wrt_variables = tuple(gen_indices())
case _:
def gen_indices():
- for variable in self.variable:
- if variable in state[0].offset_dict:
- yield state[0].offset_dict[variable][0]
+ for variable in self.variables:
+ if variable in state.offset_dict:
+ yield state.offset_dict[variable][0]
- variable_indices = tuple(sorted(gen_indices()))
+ diff_wrt_variables = tuple(gen_indices())
- terms = {}
- # aux_terms = []
+ def get_derivative_terms(
+ monomial_terms,
+ diff_wrt_variable: int,
+ state: PolyMatrixExprState,
+ considered_variables: set,
+ ):
- for var_idx, variable in enumerate(variable_indices):
-
- def get_derivative_terms(monomial_terms):
+ if self.introduce_derivatives:
+
+ def gen_new_variables():
+ for monomial in monomial_terms.keys():
+ for var in monomial:
+ if var not in diff_wrt_variables and var not in considered_variables:
+ yield var
+
+ new_variables = set(gen_new_variables())
+
+ new_considered_variables = considered_variables | new_variables
- terms_row_col = {}
+ def acc_state_candidates(acc, new_variable):
+ state, candidates = acc
- for monomial, value in monomial_terms.items():
+ key = init_derivative_key(
+ variable=new_variable,
+ with_respect_to=diff_wrt_variable,
+ )
+ state = state.register(key=key, n_param=1)
- # count powers for each variable
- monomial_cnt = dict(collections.Counter(monomial))
+ state, auxillary_derivation_terms = get_derivative_terms(
+ monomial_terms=state.auxillary_equations[new_variable],
+ diff_wrt_variable=diff_wrt_variable,
+ state=state,
+ considered_variables=new_considered_variables,
+ )
- if variable not in monomial_cnt:
- continue
+ if 1 < len(auxillary_derivation_terms):
+ derivation_variable = state.offset_dict[key][0]
+
+ state = dataclasses.replace(
+ state,
+ auxillary_equations=state.auxillary_equations | {derivation_variable: auxillary_derivation_terms},
+ )
+
+ return state, candidates + (new_variable,)
- if self.introduce_derivatives:
- variable_candidates = (variable,) + tuple(var for var in monomial_cnt.keys() if var not in variable_indices)
else:
- variable_candidates = (variable,)
-
- for variable_candidate in variable_candidates:
+ return state, candidates
- def generate_monomial():
- for current_variable, current_count in monomial_cnt.items():
+ *_, (state, confirmed_variables) = itertools.accumulate(
+ new_variables,
+ acc_state_candidates,
+ initial=(state, tuple()),
+ )
- if current_variable is variable_candidate:
- sel_counter = current_count - 1
+ else:
+ confirmed_variables = tuple()
- else:
- sel_counter = current_count
+ derivation_terms = collections.defaultdict(float)
- for _ in range(sel_counter):
- yield current_variable
+ for monomial, value in monomial_terms.items():
- if variable_candidate is not variable:
- key = init_derivative_key(
- variable=variable_candidate,
- with_respect_to=var_idx,
- )
- state[0] = state[0].register(key=key, n_param=1)
+ # count powers for each variable
+ monomial_cnt = dict(collections.Counter(monomial))
- yield state[0].offset_dict[key][0]
+ def differentiate_monomial(dependent_variable, derivation_variable=None):
+ def gen_diff_monomial():
+ for current_variable, current_count in monomial_cnt.items():
- col_monomial = tuple(generate_monomial())
+ if current_variable is dependent_variable:
+ sel_counter = current_count - 1
- if col_monomial not in terms_row_col:
- terms_row_col[col_monomial] = 0
+ else:
+ sel_counter = current_count
- terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate]
+ for _ in range(sel_counter):
+ yield current_variable
- return terms_row_col
+ if derivation_variable is not None:
+ yield derivation_variable
- for row in range(self.shape[0]):
+ diff_monomial = tuple(sorted(gen_diff_monomial()))
- try:
- underlying_terms = underlying.get_poly(row, 0)
- except KeyError:
- continue
+ return diff_monomial, value * monomial_cnt[dependent_variable]
- derivative_terms = get_derivative_terms(underlying_terms)
+ if diff_wrt_variable in monomial_cnt:
+ diff_monomial, value = differentiate_monomial(diff_wrt_variable)
+ derivation_terms[diff_monomial] += value
- if 0 < len(derivative_terms):
- terms[row, var_idx] = derivative_terms
+ for candidate_variable in monomial_cnt.keys():
+ if candidate_variable in considered_variables or candidate_variable in confirmed_variables:
+ key = init_derivative_key(
+ variable=candidate_variable,
+ with_respect_to=diff_wrt_variable,
+ )
+ derivation_variable = state.offset_dict[key][0]
- # if self.introduce_derivatives:
- # for aux_monomial in underlying.aux_terms:
+ diff_monomial, value = differentiate_monomial(
+ dependent_variable=candidate_variable,
+ derivation_variable=derivation_variable,
+ )
+ derivation_terms[diff_monomial] += value
- # derivative_terms = get_derivative_terms(aux_monomial)
+ return state, dict(derivation_terms)
+
+ state, underlying = self.underlying.apply(state=state)
+
+ terms = {}
+
+ for row in range(underlying.shape[0]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, 0)
+ except KeyError:
+ continue
+
+ # derivate each variable and map result to the corresponding column
+ for col, diff_wrt_variable in enumerate(diff_wrt_variables):
+
+ state, derivation_terms = get_derivative_terms(
+ monomial_terms=underlying_terms,
+ diff_wrt_variable=diff_wrt_variable,
+ state=state,
+ considered_variables=set(),
+ )
+
+ if 0 < len(derivation_terms):
+ terms[row, col] = derivation_terms
- # # todo: is this correct?
- # if 1 < len(derivative_terms):
- # aux_terms.append(derivative_terms)
-
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
- # aux_terms=underlying.aux_terms + tuple(aux_terms),
+ shape=(underlying.shape[0], len(diff_wrt_variables)),
)
- return state[0], poly_matrix
+ return state, poly_matrix
+
+ # def get_derivative_terms(
+ # monomial_terms,
+ # diff_wrt_variable: int,
+ # state: PolyMatrixExprState,
+ # considered_variables: set,
+ # # implement_derivation: bool,
+ # ):
+ # derivation_terms = collections.defaultdict(float)
+
+ # other_independent_variables = tuple(var for var in diff_wrt_variables if var is not diff_wrt_variable)
+
+ # # print(other_independent_variables)
+ # # print(tuple(variable for monomial in monomial_terms.keys() for variable in monomial))
+
+ # if sum(variable not in other_independent_variables for monomial in monomial_terms.keys() for variable in monomial) < 2:
+ # return {}, state
+
+ # # if not implement_derivation:
+ # # implement_derivation = any(diff_wrt_variable in monomial for monomial in monomial_terms.keys())
+
+ # for monomial, value in monomial_terms.items():
+
+ # # count powers for each variable
+ # monomial_cnt = dict(collections.Counter(monomial))
+
+ # def differentiate_monomial(dependent_variable, derivation_variable=None):
+ # def gen_diff_monomial():
+ # for current_variable, current_count in monomial_cnt.items():
+
+ # if current_variable is dependent_variable:
+ # sel_counter = current_count - 1
+
+ # else:
+ # sel_counter = current_count
+
+ # for _ in range(sel_counter):
+ # yield current_variable
+
+ # if derivation_variable is not None:
+ # yield derivation_variable
+
+ # diff_monomial = tuple(gen_diff_monomial())
+
+ # return diff_monomial, value * monomial_cnt[dependent_variable]
+
+ # if diff_wrt_variable in monomial_cnt:
+ # diff_monomial, value = differentiate_monomial(diff_wrt_variable)
+ # derivation_terms[diff_monomial] += value
+
+ # if self.introduce_derivatives:
+
+ # def gen_derivation_keys():
+ # for variable in monomial_cnt.keys():
+ # if variable not in diff_wrt_variables:
+ # yield variable
+
+ # candidate_variables = tuple(gen_derivation_keys())
+
+ # new_considered_derivations = considered_variables | set(candidate_variables)
+
+ # for candidate_variable in candidate_variables:
+
+ # # introduce new auxillary equation
+ # if candidate_variable not in considered_variables:
+ # auxillary_derivation_terms, state = get_derivative_terms(
+ # monomial_terms=state.auxillary_equations[candidate_variable],
+ # diff_wrt_variable=diff_wrt_variable,
+ # state=state,
+ # considered_variables=new_considered_derivations,
+ # # implement_derivation=implement_derivation,
+ # )
+
+ # if 0 < len(auxillary_derivation_terms):
+ # key = init_derivative_key(
+ # variable=candidate_variable,
+ # with_respect_to=diff_wrt_variable,
+ # )
+ # state = state.register(key=key, n_param=1)
+ # derivation_variable = state.offset_dict[key][0]
+
+ # state = dataclasses.replace(
+ # state,
+ # auxillary_equations=state.auxillary_equations | {derivation_variable: auxillary_derivation_terms},
+ # )
+
+ # else:
+
+ # key = init_derivative_key(
+ # variable=candidate_variable,
+ # with_respect_to=diff_wrt_variable,
+ # )
+ # state = state.register(key=key, n_param=1)
+ # derivation_variable = state.offset_dict[key][0]
+
+ # diff_monomial, value = differentiate_monomial(
+ # dependent_variable=candidate_variable,
+ # derivation_variable=derivation_variable,
+ # )
+ # derivation_terms[diff_monomial] += value
+
+ # return dict(derivation_terms), state
+
+ # terms = {}
+
+ # for row in range(self.shape[0]):
+
+ # try:
+ # underlying_terms = underlying.get_poly(row, 0)
+ # except KeyError:
+ # continue
+
+ # # derivate each variable and map result to the corresponding column
+ # for col, diff_wrt_variable in enumerate(diff_wrt_variables):
+
+ # derivation_terms, state = get_derivative_terms(
+ # monomial_terms=underlying_terms,
+ # diff_wrt_variable=diff_wrt_variable,
+ # state=state,
+ # considered_variables=set(),
+ # # implement_derivation=False,
+ # )
+
+ # if 0 < len(derivation_terms):
+ # terms[row, col] = derivation_terms
+
+ # poly_matrix = init_poly_matrix(
+ # terms=terms,
+ # shape=self.shape,
+ # )
+
+ # return state, poly_matrix
+
+
+
+ # state = [state]
+
+ # state, underlying = self.underlying.apply(state=state)
+
+ # match self.variables:
+ # case ExpressionBaseMixin():
+ # assert self.variables.shape[1] == 1
+
+ # state, dependent_variables = self.variables.apply(state)
+
+ # def gen_indices():
+ # for row in range(dependent_variables.shape[0]):
+ # for monomial in dependent_variables.get_poly(row, 0).keys():
+ # yield monomial[0]
+
+ # variable_indices = tuple(sorted(gen_indices()))
+
+ # case _:
+ # def gen_indices():
+ # for variable in self.variables:
+ # if variable in state.offset_dict:
+ # yield state.offset_dict[variable][0]
+
+ # variable_indices = tuple(sorted(gen_indices()))
+
+ # terms = {}
+ # derivations_keys = set()
+
+ # # derivate each variable and map result to the corresponding column
+ # for col, derivation_variable in enumerate(variable_indices):
+
+ # def get_derivative_terms(monomial_terms):
+
+ # terms_row_col = collections.defaultdict(float)
+
+ # for monomial, value in monomial_terms.items():
+
+ # # count powers for each variable
+ # monomial_cnt = dict(collections.Counter(monomial))
+
+ # variable_candidates = tuple()
+
+ # if derivation_variable in monomial_cnt:
+ # variable_candidates += ((derivation_variable, None),)
+
+ # if self.introduce_derivatives:
+ # def gen_dependent_variables():
+ # for dependent_variable in monomial_cnt.keys():
+ # if dependent_variable not in variable_indices:
+ # derivation_key = init_derivative_key(
+ # variable=dependent_variable,
+ # with_respect_to=derivation_variable,
+ # )
+ # derivations_keys.add(derivation_key)
+ # state = state.register(key=derivation_key, n_param=1)
+ # yield dependent_variable, derivation_key
+
+ # variable_candidates += tuple(gen_dependent_variables())
+
+ # for variable_candidate, derivation_key in variable_candidates:
+
+ # def generate_monomial():
+ # for current_variable, current_count in monomial_cnt.items():
+
+ # if current_variable is variable_candidate:
+ # sel_counter = current_count - 1
+
+ # else:
+ # sel_counter = current_count
+
+ # for _ in range(sel_counter):
+ # yield current_variable
+
+ # if derivation_key is not None:
+ # yield state.offset_dict[derivation_key][0]
+
+ # col_monomial = tuple(generate_monomial())
+
+ # terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate]
+
+ # return dict(terms_row_col)
+
+ # for row in range(self.shape[0]):
+
+ # try:
+ # underlying_terms = underlying.get_poly(row, 0)
+ # except KeyError:
+ # continue
+
+ # derivative_terms = get_derivative_terms(underlying_terms)
+
+ # if 0 < len(derivative_terms):
+ # terms[row, col] = derivative_terms
+
+ # derivation_variables = collections.defaultdict(list)
+ # for derivation_key in derivations_keys:
+ # derivation_variables[derivation_key.with_respect_to].append(derivation_key)
+
+ # aux_der_terms = []
+
+ # for derivation_variable, derivation_keys in derivation_variables.items():
+
+ # dependent_variables = tuple(derivation_key.variable for derivation_key in derivation_keys)
+
+
+ # for aux_terms in state.auxillary_equations:
+
+ # # only intoduce a new auxillary equation if there is a monomial containing at least one dependent variable
+ # if any(variable in dependent_variables for monomial in aux_terms.keys() for variable in monomial):
+
+ # terms_row_col = collections.defaultdict(float)
+
+ # # for each monomial
+ # for aux_monomial, value in aux_terms.items():
+
+ # # count powers for each variable
+ # monomial_cnt = dict(collections.Counter(aux_monomial))
+
+ # variable_candidates = tuple()
+
+ # if derivation_variable in monomial_cnt:
+ # variable_candidates += ((derivation_variable, None),)
+
+ # # add dependent variables
+ # variable_candidates += tuple((derivation_key.variable, derivation_key) for derivation_key in derivation_keys if derivation_key.variable in monomial_cnt)
+
+ # for variable_candidate, derivative_key in variable_candidates:
+
+ # def generate_monomial():
+ # for current_variable, current_count in monomial_cnt.items():
+
+ # if current_variable is variable_candidate:
+ # sel_counter = current_count - 1
+
+ # else:
+ # sel_counter = current_count
+
+ # for _ in range(sel_counter):
+ # yield current_variable
+
+ # if derivative_key is not None:
+ # yield state.offset_dict[derivative_key][0]
+
+ # col_monomial = tuple(generate_monomial())
+
+ # terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate]
+
+ # if 0 < len(terms_row_col):
+ # aux_der_terms.append(dict(terms_row_col))
+
+ # state = dataclasses.replace(
+ # state,
+ # auxillary_equations=state.auxillary_equations + tuple(aux_der_terms),
+ # )
+
+ # poly_matrix = init_poly_matrix(
+ # terms=terms,
+ # shape=self.shape,
+ # )
+
+ # return state, poly_matrix