diff options
-rw-r--r-- | polymatrix/__init__.py | 1 | ||||
-rw-r--r-- | polymatrix/expression/impl/traceexprimpl.py | 8 | ||||
-rw-r--r-- | polymatrix/expression/init/inittraceexpr.py | 10 | ||||
-rw-r--r-- | polymatrix/expression/mixins/expressionmixin.py | 9 | ||||
-rw-r--r-- | polymatrix/expression/mixins/traceexprmixin.py | 35 | ||||
-rw-r--r-- | polymatrix/expression/traceexpr.py | 4 |
6 files changed, 66 insertions, 1 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 4df9a81..1c072b4 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -20,7 +20,6 @@ from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin from polymatrix.expression.utils.monomialtoindex import monomial_to_index from polymatrix.expressionstate.init.initexpressionstate import init_expression_state as original_init_expression_state - def init_expression_state(): return original_init_expression_state() diff --git a/polymatrix/expression/impl/traceexprimpl.py b/polymatrix/expression/impl/traceexprimpl.py new file mode 100644 index 0000000..19efeeb --- /dev/null +++ b/polymatrix/expression/impl/traceexprimpl.py @@ -0,0 +1,8 @@ +import dataclass_abc +from polymatrix.expression.traceexpr import TraceExpr + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin + +@dataclass_abc.dataclass_abc(frozen=True) +class TraceExprImpl(TraceExpr): + underlying: ExpressionBaseMixin diff --git a/polymatrix/expression/init/inittraceexpr.py b/polymatrix/expression/init/inittraceexpr.py new file mode 100644 index 0000000..9069ee6 --- /dev/null +++ b/polymatrix/expression/init/inittraceexpr.py @@ -0,0 +1,10 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.traceexprimpl import TraceExprImpl + + +def init_trace_expr( + underlying: ExpressionBaseMixin, +): + return TraceExprImpl( + underlying=underlying, +) diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py index 4afa44f..43b1e41 100644 --- a/polymatrix/expression/mixins/expressionmixin.py +++ b/polymatrix/expression/mixins/expressionmixin.py @@ -33,6 +33,7 @@ from polymatrix.expression.init.initsumexpr import init_sum_expr from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr from polymatrix.expression.init.inittoconstantexpr import init_to_constant_expr from polymatrix.expression.init.inittoquadraticexpr import init_to_quadratic_expr +from polymatrix.expression.init.inittraceexpr import init_trace_expr from polymatrix.expression.init.inittransposeexpr import init_transpose_expr from polymatrix.expression.init.inittruncateexpr import init_truncate_expr @@ -446,6 +447,14 @@ class ExpressionMixin( ), ) + def trace(self): + return dataclasses.replace( + self, + underlying=init_trace_expr( + underlying=self.underlying, + ), + ) + def truncate(self, variables: tuple, degrees: tuple[int]): return dataclasses.replace( self, diff --git a/polymatrix/expression/mixins/traceexprmixin.py b/polymatrix/expression/mixins/traceexprmixin.py new file mode 100644 index 0000000..fa46967 --- /dev/null +++ b/polymatrix/expression/mixins/traceexprmixin.py @@ -0,0 +1,35 @@ +import abc +import dataclass_abc + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionStateMixin +from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin + +class TraceExprMixin(ExpressionBaseMixin): + @property + @abc.abstractclassmethod + def underlying(self) -> ExpressionBaseMixin: + ... + + # overwrites abstract method of `ExpressionBaseMixin` + def apply( + self, + state: ExpressionStateMixin, + ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + + state, underlying = self.underlying.apply(state) + + assert underlying.shape[0] == underlying.shape[1], f'{underlying.shape=}' + + @dataclass_abc.dataclass_abc(frozen=True) + class TracePolyMatrix(PolyMatrixMixin): + underlying: PolyMatrixMixin + shape: tuple[int, int] + + def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]: + return self.underlying.get_poly(row, row) + + return state, TracePolyMatrix( + underlying=underlying, + shape=(underlying.shape[0], 1), + )
\ No newline at end of file diff --git a/polymatrix/expression/traceexpr.py b/polymatrix/expression/traceexpr.py new file mode 100644 index 0000000..f35bbca --- /dev/null +++ b/polymatrix/expression/traceexpr.py @@ -0,0 +1,4 @@ +from polymatrix.expression.mixins.traceexprmixin import TraceExprMixin + +class TraceExpr(TraceExprMixin): + pass |