diff options
-rw-r--r-- | polymatrix/expression/impl.py | 7 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 4 | ||||
-rw-r--r-- | polymatrix/expression/mixins/variablemixin.py | 42 |
3 files changed, 53 insertions, 0 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 32f2552..c921caa 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -64,6 +64,7 @@ from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ( ) from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin +from polymatrix.expression.mixins.variablemixin import VariableMixin from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin from polymatrix.expression.mixins.tosortedvariablesmixin import ( ToSortedVariablesExprMixin, @@ -390,5 +391,11 @@ class TruncateExprImpl(TruncateExprMixin): @dataclassabc.dataclassabc(frozen=True) +class VariableImpl(VariableMixin): + name: str + shape: tuple[int, int] + + +@dataclassabc.dataclassabc(frozen=True) class VStackExprImpl(VStackExprMixin): underlying: tuple diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 0b4ced9..4fc401e 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -524,3 +524,7 @@ def init_truncate_expr( degrees=degrees, inverse=inverse, ) + + +def init_variable_expr(name: str, shape: tuple[int, int]): + return polymatrix.expression.impl.VariableImpl(name, shape) diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py new file mode 100644 index 0000000..ea1d222 --- /dev/null +++ b/polymatrix/expression/mixins/variablemixin.py @@ -0,0 +1,42 @@ +from abc import abstractmethod +from itertools import product +from typing_extensions import override + +from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex +from polymatrix.polymatrix.init import init_poly_matrix + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate.abc import ExpressionState + + +class VariableMixin(ExpressionBaseMixin): + """ Mixin to describe a polynomial variable, i.e. an expression that cannot + be reduced further. """ + + @property + @abstractmethod + def name(self) -> str: + """ Name of the variable. """ + + @property + @abstractmethod + def shape(self) -> tuple[int, int]: + """ Shape of the variable. """ + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: + """ See :py:meth:`ExpressionBaseMixin.apply`. """ + + state, indices = state.index(self) + p = PolyMatrixDict() + + rows, cols = self.shape + for (row, col), index in zip(product(range(rows), range(cols)), indices): + p[row, col] = PolyDict({ + # Create monomial with variable to the first power + # with coefficient of one + MonomialIndex((VariableIndex(index, power=1),)): 1. + }) + + return state, init_poly_matrix(p, self.shape) |