summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/impl.py7
-rw-r--r--polymatrix/expression/init.py4
-rw-r--r--polymatrix/expression/mixins/variablemixin.py42
3 files changed, 53 insertions, 0 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 32f2552..c921caa 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -64,6 +64,7 @@ from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import (
)
from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin
from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin
+from polymatrix.expression.mixins.variablemixin import VariableMixin
from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin
from polymatrix.expression.mixins.tosortedvariablesmixin import (
ToSortedVariablesExprMixin,
@@ -390,5 +391,11 @@ class TruncateExprImpl(TruncateExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class VariableImpl(VariableMixin):
+ name: str
+ shape: tuple[int, int]
+
+
+@dataclassabc.dataclassabc(frozen=True)
class VStackExprImpl(VStackExprMixin):
underlying: tuple
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 0b4ced9..4fc401e 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -524,3 +524,7 @@ def init_truncate_expr(
degrees=degrees,
inverse=inverse,
)
+
+
+def init_variable_expr(name: str, shape: tuple[int, int]):
+ return polymatrix.expression.impl.VariableImpl(name, shape)
diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py
new file mode 100644
index 0000000..ea1d222
--- /dev/null
+++ b/polymatrix/expression/mixins/variablemixin.py
@@ -0,0 +1,42 @@
+from abc import abstractmethod
+from itertools import product
+from typing_extensions import override
+
+from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex
+from polymatrix.polymatrix.init import init_poly_matrix
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate.abc import ExpressionState
+
+
+class VariableMixin(ExpressionBaseMixin):
+ """ Mixin to describe a polynomial variable, i.e. an expression that cannot
+ be reduced further. """
+
+ @property
+ @abstractmethod
+ def name(self) -> str:
+ """ Name of the variable. """
+
+ @property
+ @abstractmethod
+ def shape(self) -> tuple[int, int]:
+ """ Shape of the variable. """
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
+ """ See :py:meth:`ExpressionBaseMixin.apply`. """
+
+ state, indices = state.index(self)
+ p = PolyMatrixDict()
+
+ rows, cols = self.shape
+ for (row, col), index in zip(product(range(rows), range(cols)), indices):
+ p[row, col] = PolyDict({
+ # Create monomial with variable to the first power
+ # with coefficient of one
+ MonomialIndex((VariableIndex(index, power=1),)): 1.
+ })
+
+ return state, init_poly_matrix(p, self.shape)