diff options
author | Nao Pross <np@0hm.ch> | 2024-05-01 22:29:33 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-01 22:39:16 +0200 |
commit | 7074ab2e4056ef70dcdbcf8ca97d483e0f106c3c (patch) | |
tree | 89dede6e589ef1b74d0288272fa4056242815cd8 | |
parent | Minor changes to PolyMatrixAsAffineExpression (diff) | |
download | polymatrix-7074ab2e4056ef70dcdbcf8ca97d483e0f106c3c.tar.gz polymatrix-7074ab2e4056ef70dcdbcf8ca97d483e0f106c3c.zip |
Replace VariableMixin with Variable & VariableExpression
-rw-r--r-- | polymatrix/expression/expression.py | 66 | ||||
-rw-r--r-- | polymatrix/expression/from_.py | 10 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 7 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 4 | ||||
-rw-r--r-- | polymatrix/expression/mixins/variablemixin.py | 45 | ||||
-rw-r--r-- | polymatrix/expressionstate/mixins.py | 16 | ||||
-rw-r--r-- | polymatrix/variable/__init__.py | 0 | ||||
-rw-r--r-- | polymatrix/variable/abc.py | 13 | ||||
-rw-r--r-- | polymatrix/variable/impl.py | 8 | ||||
-rw-r--r-- | polymatrix/variable/init.py | 5 |
10 files changed, 92 insertions, 82 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 4ef955d..0e15f44 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -1,19 +1,25 @@ from __future__ import annotations -import abc import dataclasses -import dataclassabc import typing +import itertools import numpy as np +from abc import ABC, abstractmethod +from dataclassabc import dataclassabc +from typing_extensions import override + if typing.TYPE_CHECKING: from polymatrix.expressionstate.abc import ExpressionState +from polymatrix.variable.abc import Variable import polymatrix.expression.init from polymatrix.utils.getstacklines import get_stack_lines from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.op import ( diff, @@ -25,12 +31,9 @@ from polymatrix.expression.op import ( degree, ) -class Expression( - ExpressionBaseMixin, - abc.ABC, -): +class Expression(ExpressionBaseMixin, ABC): @property - @abc.abstractmethod + @abstractmethod def underlying(self) -> ExpressionBaseMixin: # NP: I know what it is but it was very confusing in the beginning. # FIXME: provide documentation on how underlying works / what it means @@ -67,7 +70,7 @@ class Expression( ) def __matmul__( - self, other: typing.Union[ExpressionBaseMixin, np.ndarray] + self, other: ExpressionBaseMixin | np.ndarray ) -> "Expression": return self._binary( polymatrix.expression.init.init_matrix_mult_expr, self, other @@ -109,7 +112,7 @@ class Expression( def __truediv__(self, other: ExpressionBaseMixin): return self._binary(polymatrix.expression.init.init_division_expr, self, other) - @abc.abstractmethod + @abstractmethod def copy(self, underlying: ExpressionBaseMixin) -> "Expression": ... @staticmethod @@ -500,7 +503,7 @@ class Expression( # This is here and not in impl.py because of circular imports # FIXME: this is not ideal -@dataclassabc.dataclassabc(frozen=True, repr=False) +@dataclassabc(frozen=True, repr=False) class ExpressionImpl(Expression): underlying: ExpressionBaseMixin @@ -523,3 +526,46 @@ def init_expression( return ExpressionImpl( underlying=underlying, ) + + +class VariableExpression(Expression, Variable): + """ + Expression that is a polynomial variable, i.e. an expression that cannot be + reduced further. + """ + @property + @override + def underlying(self): + return None + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: + # Since there is no underlying this class directly does the job of + # creating a polymatrix + + state, indices = state.index(self) + p = PolyMatrixDict() + + rows, cols = self.shape + for (row, col), index in zip(itertools.product(range(rows), range(cols)), indices): + p[row, col] = PolyDict({ + # Create monomial with variable to the first power + # with coefficient of one + MonomialIndex((VariableIndex(index, power=1),)): 1. + }) + + return state, init_poly_matrix(p, self.shape) + + +@dataclassabc(frozen=True) +class VariableExpressionImpl(VariableExpression): + name: str + shape: tuple[int, int] + + @override + def copy(self, underlying: ExpressionBaseMixin) -> "Expression": + return self + + +def init_variable_expression(name: str, shape: tuple[int, int] = (1, 1)): + return VariableExpressionImpl(name, shape) diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index af11d49..378a4b7 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -5,10 +5,8 @@ from polymatrix.expression.typing import FromDataTypes import polymatrix.expression.init -from polymatrix.expression.expression import init_expression, Expression +from polymatrix.expression.expression import init_expression, Expression, init_variable_expression, VariableExpression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expression.mixins.variablemixin import VariableMixin -from polymatrix.expression.init import init_variable_expr from polymatrix.statemonad.abc import StateMonad @@ -37,10 +35,10 @@ def from_( ) -def from_names(names: str, shape: tuple[int, int] = (1,1)) -> tuple[VariableMixin] | VariableMixin: +def from_names(names: str, shape: tuple[int, int] = (1,1)) -> tuple[VariableExpression] | VariableExpression: """ Construct one or multiple variables from comma separated a list of names. """ - variables = tuple(init_expression(underlying=init_variable_expr(name.strip(), shape)) - for name in names.split(",")) + variables = tuple(init_variable_expression(name.strip(), shape) + for name in names.split(",")) if len(variables) == 1: return variables[0] diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index d4f5429..7e23329 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -68,7 +68,6 @@ from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ( ) from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin -from polymatrix.expression.mixins.variablemixin import VariableMixin from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin from polymatrix.expression.mixins.tosortedvariablesmixin import ( ToSortedVariablesExprMixin, @@ -400,11 +399,5 @@ class TruncateExprImpl(TruncateExprMixin): @dataclassabc.dataclassabc(frozen=True) -class VariableImpl(VariableMixin): - name: str - shape: tuple[int, int] - - -@dataclassabc.dataclassabc(frozen=True) class VStackExprImpl(VStackExprMixin): underlying: tuple diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 83ee26f..0e6faf9 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -529,7 +529,3 @@ def init_truncate_expr( degrees=degrees, inverse=inverse, ) - - -def init_variable_expr(name: str, shape: tuple[int, int] = (1, 1)): - return polymatrix.expression.impl.VariableImpl(name, shape) diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py deleted file mode 100644 index da72b65..0000000 --- a/polymatrix/expression/mixins/variablemixin.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from itertools import product -from typing_extensions import override -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from polymatrix.expressionstate.mixins import ExpressionStateMixin - -from polymatrix.polymatrix.mixins import PolyMatrixMixin -from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex -from polymatrix.polymatrix.init import init_poly_matrix - -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin - - -class VariableMixin(ExpressionBaseMixin): - """ Mixin to describe a polynomial variable, i.e. an expression that cannot - be reduced further. """ - - @property - @abstractmethod - def name(self) -> str: - """ Name of the variable. """ - - @property - @abstractmethod - def shape(self) -> tuple[int, int]: - """ Shape of the variable. """ - - @override - def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: - state, indices = state.index(self) - p = PolyMatrixDict() - - rows, cols = self.shape - for (row, col), index in zip(product(range(rows), range(cols)), indices): - p[row, col] = PolyDict({ - # Create monomial with variable to the first power - # with coefficient of one - MonomialIndex((VariableIndex(index, power=1),)): 1. - }) - - return state, init_poly_matrix(p, self.shape) diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py index f937137..836032e 100644 --- a/polymatrix/expressionstate/mixins.py +++ b/polymatrix/expressionstate/mixins.py @@ -7,7 +7,7 @@ from math import prod import dataclasses from polymatrix.expression.expression import Expression -from polymatrix.expression.mixins.variablemixin import VariableMixin +from polymatrix.variable.abc import Variable from polymatrix.statemonad.mixins import StateCacheMixin @@ -30,18 +30,14 @@ class ExpressionStateMixin( @property @abstractmethod - def indices(self) -> dict[VariableMixin, IndexRange]: + def indices(self) -> dict[Variable, IndexRange]: """ Map from variable objects to their indices. """ - def index(self, var: VariableMixin) -> tuple[ExpressionStateMixin, IndexRange]: + def index(self, var: Variable) -> tuple[ExpressionStateMixin, IndexRange]: """ Get the index of a variable. """ # Unwrap if wrapped in expression object - if isinstance(var, Expression): - var = var.underlying - - if not isinstance(var, VariableMixin): - raise ValueError("State can only index object of type VariableMixin or expressions " - "that contain a single VariableMixin object!") + if not isinstance(var, Variable): + raise ValueError("State can only index object of type Variable!") # Check if already in there if var in self.indices.keys(): @@ -62,7 +58,7 @@ class ExpressionStateMixin( indices={**self.indices, var: index} ), index - def register(self, var: VariableMixin) -> ExpressionStateMixin: + def register(self, var: Variable) -> ExpressionStateMixin: """ Create an index for a variable, but does not return the index. If you want the index use diff --git a/polymatrix/variable/__init__.py b/polymatrix/variable/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/polymatrix/variable/__init__.py diff --git a/polymatrix/variable/abc.py b/polymatrix/variable/abc.py new file mode 100644 index 0000000..70ad8de --- /dev/null +++ b/polymatrix/variable/abc.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod + +class Variable(ABC): + @property + @abstractmethod + def name(self) -> str: + """ Name of the variable. """ + + @property + @abstractmethod + def shape(self) -> tuple[int, int]: + """ Shape of the variable. """ + diff --git a/polymatrix/variable/impl.py b/polymatrix/variable/impl.py new file mode 100644 index 0000000..47e2041 --- /dev/null +++ b/polymatrix/variable/impl.py @@ -0,0 +1,8 @@ +from dataclassabc import dataclassabc + +from .abc import Variable + +@dataclassabc(frozen=True) +class VariableImpl(Variable): + name: str + shape: tuple[int, int] diff --git a/polymatrix/variable/init.py b/polymatrix/variable/init.py new file mode 100644 index 0000000..de217bd --- /dev/null +++ b/polymatrix/variable/init.py @@ -0,0 +1,5 @@ +from .abc import Variable +from .impl import VariableImpl + +def init_variable(name: str, shape: tuple[int, int]) -> Variable: + return VariableImpl(str, shape) |