summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/mixins/traceexprmixin.py
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/mixins/traceexprmixin.py')
-rw-r--r--polymatrix/expression/mixins/traceexprmixin.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/polymatrix/expression/mixins/traceexprmixin.py b/polymatrix/expression/mixins/traceexprmixin.py
new file mode 100644
index 0000000..fa46967
--- /dev/null
+++ b/polymatrix/expression/mixins/traceexprmixin.py
@@ -0,0 +1,35 @@
+import abc
+import dataclass_abc
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
+
+class TraceExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractclassmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ state, underlying = self.underlying.apply(state)
+
+ assert underlying.shape[0] == underlying.shape[1], f'{underlying.shape=}'
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class TracePolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]:
+ return self.underlying.get_poly(row, row)
+
+ return state, TracePolyMatrix(
+ underlying=underlying,
+ shape=(underlying.shape[0], 1),
+ ) \ No newline at end of file