diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-04-12 09:30:35 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-04-12 09:30:35 +0200 |
commit | f9ec423683462800c757119c08947b742458639e (patch) | |
tree | c75a4686696368b160a46ed4f67601be8581492f /polymatrix/expression/mixins/derivativeexprmixin.py | |
parent | remove shape property, introduce accumulate function, reimplement derivative ... (diff) | |
download | polymatrix-f9ec423683462800c757119c08947b742458639e.tar.gz polymatrix-f9ec423683462800c757119c08947b742458639e.zip |
bug fixes and clean ups
Diffstat (limited to 'polymatrix/expression/mixins/derivativeexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/derivativeexprmixin.py | 47 |
1 files changed, 25 insertions, 22 deletions
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index 6098aa9..43c0efe 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -11,7 +11,8 @@ from polymatrix.expression.init.initderivativekey import init_derivative_key from polymatrix.expression.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.polymatrix import PolyMatrix -from polymatrix.polymatrixexprstate import PolyMatrixExprState +from polymatrix.expression.expressionstate import ExpressionState +from polymatrix.expression.utils.getvariableindices import get_variable_indices class DerivativeExprMixin(ExpressionBaseMixin): @@ -45,39 +46,41 @@ class DerivativeExprMixin(ExpressionBaseMixin): # overwrites abstract method of `ExpressionBaseMixin` def apply( self, - state: PolyMatrixExprState, - ) -> tuple[PolyMatrixExprState, PolyMatrix]: + state: ExpressionState, + ) -> tuple[ExpressionState, PolyMatrix]: - match self.variables: - case ExpressionBaseMixin(): - state, variables = self.variables.apply(state) + # match self.variables: + # case ExpressionBaseMixin(): + # state, variables = self.variables.apply(state) - assert variables.shape[1] == 1 + # assert variables.shape[1] == 1 - def gen_indices(): - for row in range(variables.shape[0]): - row_terms = variables.get_poly(row, 0) + # def gen_indices(): + # for row in range(variables.shape[0]): + # row_terms = variables.get_poly(row, 0) - assert len(row_terms) == 1 + # assert len(row_terms) == 1 - for monomial in row_terms.keys(): - assert len(monomial) == 1 - yield monomial[0] + # for monomial in row_terms.keys(): + # assert len(monomial) == 1 + # yield monomial[0] - diff_wrt_variables = tuple(gen_indices()) + # diff_wrt_variables = tuple(gen_indices()) - case _: - def gen_indices(): - for variable in self.variables: - if variable in state.offset_dict: - yield state.offset_dict[variable][0] + # case _: + # def gen_indices(): + # for variable in self.variables: + # if variable in state.offset_dict: + # yield state.offset_dict[variable][0] - diff_wrt_variables = tuple(gen_indices()) + # diff_wrt_variables = tuple(gen_indices()) + + state, diff_wrt_variables = get_variable_indices(self.variables, state) def get_derivative_terms( monomial_terms, diff_wrt_variable: int, - state: PolyMatrixExprState, + state: ExpressionState, considered_variables: set, ): |