summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-08-05 10:52:22 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-08-05 10:52:22 +0200
commite9641d1e8aae194dbee4f294d11131cefb595783 (patch)
tree9e7d86ac4104c784cc64d74a08858aa821890bcb
parentadd 'install_requires' to setup.py, add missing link to README.md (diff)
downloadpolymatrix-e9641d1e8aae194dbee4f294d11131cefb595783.tar.gz
polymatrix-e9641d1e8aae194dbee4f294d11131cefb595783.zip
add trace operator
-rw-r--r--polymatrix/__init__.py1
-rw-r--r--polymatrix/expression/impl/traceexprimpl.py8
-rw-r--r--polymatrix/expression/init/inittraceexpr.py10
-rw-r--r--polymatrix/expression/mixins/expressionmixin.py9
-rw-r--r--polymatrix/expression/mixins/traceexprmixin.py35
-rw-r--r--polymatrix/expression/traceexpr.py4
6 files changed, 66 insertions, 1 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 4df9a81..1c072b4 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -20,7 +20,6 @@ from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin
from polymatrix.expression.utils.monomialtoindex import monomial_to_index
from polymatrix.expressionstate.init.initexpressionstate import init_expression_state as original_init_expression_state
-
def init_expression_state():
return original_init_expression_state()
diff --git a/polymatrix/expression/impl/traceexprimpl.py b/polymatrix/expression/impl/traceexprimpl.py
new file mode 100644
index 0000000..19efeeb
--- /dev/null
+++ b/polymatrix/expression/impl/traceexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.traceexpr import TraceExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class TraceExprImpl(TraceExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/init/inittraceexpr.py b/polymatrix/expression/init/inittraceexpr.py
new file mode 100644
index 0000000..9069ee6
--- /dev/null
+++ b/polymatrix/expression/init/inittraceexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.traceexprimpl import TraceExprImpl
+
+
+def init_trace_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return TraceExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py
index 4afa44f..43b1e41 100644
--- a/polymatrix/expression/mixins/expressionmixin.py
+++ b/polymatrix/expression/mixins/expressionmixin.py
@@ -33,6 +33,7 @@ from polymatrix.expression.init.initsumexpr import init_sum_expr
from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr
from polymatrix.expression.init.inittoconstantexpr import init_to_constant_expr
from polymatrix.expression.init.inittoquadraticexpr import init_to_quadratic_expr
+from polymatrix.expression.init.inittraceexpr import init_trace_expr
from polymatrix.expression.init.inittransposeexpr import init_transpose_expr
from polymatrix.expression.init.inittruncateexpr import init_truncate_expr
@@ -446,6 +447,14 @@ class ExpressionMixin(
),
)
+ def trace(self):
+ return dataclasses.replace(
+ self,
+ underlying=init_trace_expr(
+ underlying=self.underlying,
+ ),
+ )
+
def truncate(self, variables: tuple, degrees: tuple[int]):
return dataclasses.replace(
self,
diff --git a/polymatrix/expression/mixins/traceexprmixin.py b/polymatrix/expression/mixins/traceexprmixin.py
new file mode 100644
index 0000000..fa46967
--- /dev/null
+++ b/polymatrix/expression/mixins/traceexprmixin.py
@@ -0,0 +1,35 @@
+import abc
+import dataclass_abc
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
+
+class TraceExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractclassmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ state, underlying = self.underlying.apply(state)
+
+ assert underlying.shape[0] == underlying.shape[1], f'{underlying.shape=}'
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class TracePolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]:
+ return self.underlying.get_poly(row, row)
+
+ return state, TracePolyMatrix(
+ underlying=underlying,
+ shape=(underlying.shape[0], 1),
+ ) \ No newline at end of file
diff --git a/polymatrix/expression/traceexpr.py b/polymatrix/expression/traceexpr.py
new file mode 100644
index 0000000..f35bbca
--- /dev/null
+++ b/polymatrix/expression/traceexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.traceexprmixin import TraceExprMixin
+
+class TraceExpr(TraceExprMixin):
+ pass