summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/derivativeexprmixin.py
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-04-12 09:30:35 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-04-12 09:30:35 +0200
commitf9ec423683462800c757119c08947b742458639e (patch)
treec75a4686696368b160a46ed4f67601be8581492f /polymatrix/expression/mixins/derivativeexprmixin.py
parentremove shape property, introduce accumulate function, reimplement derivative ... (diff)
downloadpolymatrix-f9ec423683462800c757119c08947b742458639e.tar.gz
polymatrix-f9ec423683462800c757119c08947b742458639e.zip
bug fixes and clean ups
Diffstat (limited to 'polymatrix/expression/mixins/derivativeexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py47
1 files changed, 25 insertions, 22 deletions
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index 6098aa9..43c0efe 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -11,7 +11,8 @@ from polymatrix.expression.init.initderivativekey import init_derivative_key
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.polymatrix import PolyMatrix
-from polymatrix.polymatrixexprstate import PolyMatrixExprState
+from polymatrix.expression.expressionstate import ExpressionState
+from polymatrix.expression.utils.getvariableindices import get_variable_indices
class DerivativeExprMixin(ExpressionBaseMixin):
@@ -45,39 +46,41 @@ class DerivativeExprMixin(ExpressionBaseMixin):
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
self,
- state: PolyMatrixExprState,
- ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
- match self.variables:
- case ExpressionBaseMixin():
- state, variables = self.variables.apply(state)
+ # match self.variables:
+ # case ExpressionBaseMixin():
+ # state, variables = self.variables.apply(state)
- assert variables.shape[1] == 1
+ # assert variables.shape[1] == 1
- def gen_indices():
- for row in range(variables.shape[0]):
- row_terms = variables.get_poly(row, 0)
+ # def gen_indices():
+ # for row in range(variables.shape[0]):
+ # row_terms = variables.get_poly(row, 0)
- assert len(row_terms) == 1
+ # assert len(row_terms) == 1
- for monomial in row_terms.keys():
- assert len(monomial) == 1
- yield monomial[0]
+ # for monomial in row_terms.keys():
+ # assert len(monomial) == 1
+ # yield monomial[0]
- diff_wrt_variables = tuple(gen_indices())
+ # diff_wrt_variables = tuple(gen_indices())
- case _:
- def gen_indices():
- for variable in self.variables:
- if variable in state.offset_dict:
- yield state.offset_dict[variable][0]
+ # case _:
+ # def gen_indices():
+ # for variable in self.variables:
+ # if variable in state.offset_dict:
+ # yield state.offset_dict[variable][0]
- diff_wrt_variables = tuple(gen_indices())
+ # diff_wrt_variables = tuple(gen_indices())
+
+ state, diff_wrt_variables = get_variable_indices(self.variables, state)
def get_derivative_terms(
monomial_terms,
diff_wrt_variable: int,
- state: PolyMatrixExprState,
+ state: ExpressionState,
considered_variables: set,
):