diff options
Diffstat (limited to 'polymatrix/expression/mixins/linearinexprmixin.py')
-rw-r--r-- | polymatrix/expression/mixins/linearinexprmixin.py | 44 |
1 files changed, 24 insertions, 20 deletions
diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index a900b7f..8354754 100644 --- a/polymatrix/expression/mixins/linearinexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -1,4 +1,3 @@ - import abc import collections @@ -7,7 +6,9 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState from polymatrix.expression.utils.getmonomialindices import get_monomial_indices -from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable +from polymatrix.expression.utils.getvariableindices import ( + get_variable_indices_from_variable, +) class LinearInExprMixin(ExpressionBaseMixin): @@ -33,48 +34,51 @@ class LinearInExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def underlying(self) -> ExpressionBaseMixin: - ... + def underlying(self) -> ExpressionBaseMixin: ... @property @abc.abstractmethod - def monomials(self) -> ExpressionBaseMixin: - ... + def monomials(self) -> ExpressionBaseMixin: ... @property @abc.abstractmethod - def variables(self) -> ExpressionBaseMixin: - ... + def variables(self) -> ExpressionBaseMixin: ... @property @abc.abstractmethod - def ignore_unmatched(self) -> bool: - ... + def ignore_unmatched(self) -> bool: ... # overwrites the abstract method of `ExpressionBaseMixin` def apply( - self, + self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: - state, underlying = self.underlying.apply(state=state) state, monomials = get_monomial_indices(state, self.monomials) - state, variable_indices = get_variable_indices_from_variable(state, self.variables) + state, variable_indices = get_variable_indices_from_variable( + state, self.variables + ) assert underlying.shape[1] == 1 poly_matrix_data = collections.defaultdict(dict) for row in range(underlying.shape[0]): - polynomial = underlying.get_poly(row, 0) if polynomial is None: continue for monomial, value in polynomial.items(): - - x_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx in variable_indices) - p_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx not in variable_indices) + x_monomial = tuple( + (var_idx, count) + for var_idx, count in monomial + if var_idx in variable_indices + ) + p_monomial = tuple( + (var_idx, count) + for var_idx, count in monomial + if var_idx not in variable_indices + ) try: col = monomials.index(x_monomial) @@ -82,13 +86,13 @@ class LinearInExprMixin(ExpressionBaseMixin): if self.ignore_unmatched: continue else: - raise Exception(f'{x_monomial} not in {monomials}') + raise Exception(f"{x_monomial} not in {monomials}") poly_matrix_data[row, col][p_monomial] = value - + poly_matrix = init_poly_matrix( data=dict(poly_matrix_data), shape=(underlying.shape[0], len(monomials)), ) - return state, poly_matrix + return state, poly_matrix |