summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/linearinexprmixin.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/mixins/linearinexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/linearinexprmixin.py44
1 files changed, 24 insertions, 20 deletions
diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py
index a900b7f..8354754 100644
--- a/polymatrix/expression/mixins/linearinexprmixin.py
+++ b/polymatrix/expression/mixins/linearinexprmixin.py
@@ -1,4 +1,3 @@
-
import abc
import collections
@@ -7,7 +6,9 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.utils.getmonomialindices import get_monomial_indices
-from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
+from polymatrix.expression.utils.getvariableindices import (
+ get_variable_indices_from_variable,
+)
class LinearInExprMixin(ExpressionBaseMixin):
@@ -33,48 +34,51 @@ class LinearInExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def underlying(self) -> ExpressionBaseMixin:
- ...
+ def underlying(self) -> ExpressionBaseMixin: ...
@property
@abc.abstractmethod
- def monomials(self) -> ExpressionBaseMixin:
- ...
+ def monomials(self) -> ExpressionBaseMixin: ...
@property
@abc.abstractmethod
- def variables(self) -> ExpressionBaseMixin:
- ...
+ def variables(self) -> ExpressionBaseMixin: ...
@property
@abc.abstractmethod
- def ignore_unmatched(self) -> bool:
- ...
+ def ignore_unmatched(self) -> bool: ...
# overwrites the abstract method of `ExpressionBaseMixin`
def apply(
- self,
+ self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
-
state, underlying = self.underlying.apply(state=state)
state, monomials = get_monomial_indices(state, self.monomials)
- state, variable_indices = get_variable_indices_from_variable(state, self.variables)
+ state, variable_indices = get_variable_indices_from_variable(
+ state, self.variables
+ )
assert underlying.shape[1] == 1
poly_matrix_data = collections.defaultdict(dict)
for row in range(underlying.shape[0]):
-
polynomial = underlying.get_poly(row, 0)
if polynomial is None:
continue
for monomial, value in polynomial.items():
-
- x_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx in variable_indices)
- p_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx not in variable_indices)
+ x_monomial = tuple(
+ (var_idx, count)
+ for var_idx, count in monomial
+ if var_idx in variable_indices
+ )
+ p_monomial = tuple(
+ (var_idx, count)
+ for var_idx, count in monomial
+ if var_idx not in variable_indices
+ )
try:
col = monomials.index(x_monomial)
@@ -82,13 +86,13 @@ class LinearInExprMixin(ExpressionBaseMixin):
if self.ignore_unmatched:
continue
else:
- raise Exception(f'{x_monomial} not in {monomials}')
+ raise Exception(f"{x_monomial} not in {monomials}")
poly_matrix_data[row, col][p_monomial] = value
-
+
poly_matrix = init_poly_matrix(
data=dict(poly_matrix_data),
shape=(underlying.shape[0], len(monomials)),
)
- return state, poly_matrix
+ return state, poly_matrix