summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/diagexpr.py4
-rw-r--r--polymatrix/expression/impl/diagexprimpl.py8
-rw-r--r--polymatrix/expression/init/initdiagexpr.py10
-rw-r--r--polymatrix/expression/mixins/diagexprmixin.py35
4 files changed, 57 insertions, 0 deletions
diff --git a/polymatrix/expression/diagexpr.py b/polymatrix/expression/diagexpr.py
new file mode 100644
index 0000000..1fcf14e
--- /dev/null
+++ b/polymatrix/expression/diagexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.diagexprmixin import DiagExprMixin
+
+class DiagExpr(DiagExprMixin):
+ pass
diff --git a/polymatrix/expression/impl/diagexprimpl.py b/polymatrix/expression/impl/diagexprimpl.py
new file mode 100644
index 0000000..950f2b8
--- /dev/null
+++ b/polymatrix/expression/impl/diagexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.diagexpr import DiagExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class DiagExprImpl(DiagExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/init/initdiagexpr.py b/polymatrix/expression/init/initdiagexpr.py
new file mode 100644
index 0000000..db9b6d4
--- /dev/null
+++ b/polymatrix/expression/init/initdiagexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.diagexprimpl import DiagExprImpl
+
+
+def init_diag_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return DiagExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py
new file mode 100644
index 0000000..3cf8573
--- /dev/null
+++ b/polymatrix/expression/mixins/diagexprmixin.py
@@ -0,0 +1,35 @@
+import abc
+import dataclass_abc
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
+
+class DiagExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractclassmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ state, underlying = self.underlying.apply(state)
+
+ assert underlying.shape[0] == underlying.shape[1], f'{underlying.shape=}'
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class TracePolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]:
+ return self.underlying.get_poly(row, row)
+
+ return state, TracePolyMatrix(
+ underlying=underlying,
+ shape=(underlying.shape[0], 1),
+ ) \ No newline at end of file