summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/derivativeexprmixin.py
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-05-03 11:13:25 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-05-03 11:13:25 +0200
commita9aa5ea1b606a160596f684cc3c4d3b9a578f287 (patch)
treeea1993a9cac1b871321f75244a636599f6ec3203 /polymatrix/expression/mixins/derivativeexprmixin.py
parentbug fixes and clean ups (diff)
downloadpolymatrix-a9aa5ea1b606a160596f684cc3c4d3b9a578f287.tar.gz
polymatrix-a9aa5ea1b606a160596f684cc3c4d3b9a578f287.zip
add statemonad syntax
Diffstat (limited to 'polymatrix/expression/mixins/derivativeexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py46
1 files changed, 3 insertions, 43 deletions
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index 43c0efe..a551c1f 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -4,8 +4,6 @@ import collections
import dataclasses
import itertools
import typing
-import dataclass_abc
-from numpy import nonzero, var
from polymatrix.expression.init.initderivativekey import init_derivative_key
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
@@ -31,51 +29,15 @@ class DerivativeExprMixin(ExpressionBaseMixin):
def introduce_derivatives(self) -> bool:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # match self.variables:
- # case ExpressionBaseMixin():
- # n_cols = self.variables.shape[0]
-
- # case _:
- # n_cols = len(self.variables)
-
- # return self.underlying.shape[0], n_cols
-
# overwrites abstract method of `ExpressionBaseMixin`
- def apply(
+ def _apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
- # match self.variables:
- # case ExpressionBaseMixin():
- # state, variables = self.variables.apply(state)
-
- # assert variables.shape[1] == 1
-
- # def gen_indices():
- # for row in range(variables.shape[0]):
- # row_terms = variables.get_poly(row, 0)
-
- # assert len(row_terms) == 1
-
- # for monomial in row_terms.keys():
- # assert len(monomial) == 1
- # yield monomial[0]
-
- # diff_wrt_variables = tuple(gen_indices())
-
- # case _:
- # def gen_indices():
- # for variable in self.variables:
- # if variable in state.offset_dict:
- # yield state.offset_dict[variable][0]
-
- # diff_wrt_variables = tuple(gen_indices())
+ state, underlying = self.underlying.apply(state=state)
- state, diff_wrt_variables = get_variable_indices(self.variables, state)
+ state, diff_wrt_variables = get_variable_indices(state, self.variables)
def get_derivative_terms(
monomial_terms,
@@ -181,8 +143,6 @@ class DerivativeExprMixin(ExpressionBaseMixin):
return state, dict(derivation_terms)
- state, underlying = self.underlying.apply(state=state)
-
terms = {}
for row in range(underlying.shape[0]):