diff options
-rw-r--r-- | polymatrix/expression/impl/truncateexprimpl.py | 1 | ||||
-rw-r--r-- | polymatrix/expression/init/inittruncateexpr.py | 7 | ||||
-rw-r--r-- | polymatrix/expression/mixins/truncateexprmixin.py | 8 |
3 files changed, 13 insertions, 3 deletions
diff --git a/polymatrix/expression/impl/truncateexprimpl.py b/polymatrix/expression/impl/truncateexprimpl.py index e4d10ae..5bd45f3 100644 --- a/polymatrix/expression/impl/truncateexprimpl.py +++ b/polymatrix/expression/impl/truncateexprimpl.py @@ -8,3 +8,4 @@ class TruncateExprImpl(TruncateExpr): underlying: ExpressionBaseMixin variables: ExpressionBaseMixin degrees: tuple[int] + inverse: bool diff --git a/polymatrix/expression/init/inittruncateexpr.py b/polymatrix/expression/init/inittruncateexpr.py index 060ebaf..4582bcb 100644 --- a/polymatrix/expression/init/inittruncateexpr.py +++ b/polymatrix/expression/init/inittruncateexpr.py @@ -6,12 +6,17 @@ def init_truncate_expr( underlying: ExpressionBaseMixin, variables: ExpressionBaseMixin, degrees: tuple[int], + inverse: bool = None, ): if isinstance(degrees, int): degrees = (degrees,) + if inverse is None: + inverse = False + return TruncateExprImpl( underlying=underlying, variables=variables, degrees=degrees, -) + inverse=inverse, + ) diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py index 2421f48..42f192e 100644 --- a/polymatrix/expression/mixins/truncateexprmixin.py +++ b/polymatrix/expression/mixins/truncateexprmixin.py @@ -8,7 +8,6 @@ from polymatrix.expressionstate.expressionstate import ExpressionState from polymatrix.expression.utils.getvariableindices import get_variable_indices -# replace by filter operation? class TruncateExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod @@ -25,6 +24,11 @@ class TruncateExprMixin(ExpressionBaseMixin): def degrees(self) -> tuple[int]: ... + @property + @abc.abstractmethod + def inverse(self) -> bool: + ... + # overwrites abstract method of `ExpressionBaseMixin` def apply( self, @@ -49,7 +53,7 @@ class TruncateExprMixin(ExpressionBaseMixin): degree = sum((count for var_idx, count in monomial if var_idx in variable_indices)) - if degree in self.degrees: + if (degree in self.degrees) is not self.inverse: terms_row_col[monomial] = value terms[row, col] = terms_row_col |