diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-05-03 11:13:25 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-05-03 11:13:25 +0200 |
commit | a9aa5ea1b606a160596f684cc3c4d3b9a578f287 (patch) | |
tree | ea1993a9cac1b871321f75244a636599f6ec3203 /polymatrix/expression/mixins/derivativeexprmixin.py | |
parent | bug fixes and clean ups (diff) | |
download | polymatrix-a9aa5ea1b606a160596f684cc3c4d3b9a578f287.tar.gz polymatrix-a9aa5ea1b606a160596f684cc3c4d3b9a578f287.zip |
add statemonad syntax
Diffstat (limited to 'polymatrix/expression/mixins/derivativeexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/derivativeexprmixin.py | 46 |
1 files changed, 3 insertions, 43 deletions
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index 43c0efe..a551c1f 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -4,8 +4,6 @@ import collections import dataclasses import itertools import typing -import dataclass_abc -from numpy import nonzero, var from polymatrix.expression.init.initderivativekey import init_derivative_key from polymatrix.expression.init.initpolymatrix import init_poly_matrix @@ -31,51 +29,15 @@ class DerivativeExprMixin(ExpressionBaseMixin): def introduce_derivatives(self) -> bool: ... - # # overwrites abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # match self.variables: - # case ExpressionBaseMixin(): - # n_cols = self.variables.shape[0] - - # case _: - # n_cols = len(self.variables) - - # return self.underlying.shape[0], n_cols - # overwrites abstract method of `ExpressionBaseMixin` - def apply( + def _apply( self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: - # match self.variables: - # case ExpressionBaseMixin(): - # state, variables = self.variables.apply(state) - - # assert variables.shape[1] == 1 - - # def gen_indices(): - # for row in range(variables.shape[0]): - # row_terms = variables.get_poly(row, 0) - - # assert len(row_terms) == 1 - - # for monomial in row_terms.keys(): - # assert len(monomial) == 1 - # yield monomial[0] - - # diff_wrt_variables = tuple(gen_indices()) - - # case _: - # def gen_indices(): - # for variable in self.variables: - # if variable in state.offset_dict: - # yield state.offset_dict[variable][0] - - # diff_wrt_variables = tuple(gen_indices()) + state, underlying = self.underlying.apply(state=state) - state, diff_wrt_variables = get_variable_indices(self.variables, state) + state, diff_wrt_variables = get_variable_indices(state, self.variables) def get_derivative_terms( monomial_terms, @@ -181,8 +143,6 @@ class DerivativeExprMixin(ExpressionBaseMixin): return state, dict(derivation_terms) - state, underlying = self.underlying.apply(state=state) - terms = {} for row in range(underlying.shape[0]): |