diff options
author | Nao Pross <np@0hm.ch> | 2024-05-02 11:19:00 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-02 16:57:53 +0200 |
commit | dab56f94d084226afd8f60c83a8e583158046081 (patch) | |
tree | 56cb8ebf97444f4c6fcc824d7e20123142d81ef6 | |
parent | Delete DeterminantExpr, DivisionExpr, ToQuadraticExpr (diff) | |
download | polymatrix-dab56f94d084226afd8f60c83a8e583158046081.tar.gz polymatrix-dab56f94d084226afd8f60c83a8e583158046081.zip |
Fix incorrect implementation of VariableExpression
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/expression.py | 51 | ||||
-rw-r--r-- | polymatrix/expression/from_.py | 5 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 8 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 8 | ||||
-rw-r--r-- | polymatrix/expression/mixins/variablemixin.py | 34 |
5 files changed, 74 insertions, 32 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 9d910db..a7e2025 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -2,7 +2,6 @@ from __future__ import annotations import dataclasses import typing -import itertools import numpy as np from abc import ABC, abstractmethod @@ -18,9 +17,8 @@ import polymatrix.expression.init from polymatrix.utils.getstacklines import get_stack_lines from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.variablemixin import VariableMixin from polymatrix.expression.op import ( diff, integrate, @@ -508,7 +506,7 @@ class ExpressionImpl(Expression): def __repr__(self) -> str: return self.underlying.__repr__() - def copy(self, underlying: ExpressionBaseMixin) -> "Expression": + def copy(self, underlying: ExpressionBaseMixin) -> Expression: return dataclasses.replace( self, underlying=underlying, @@ -528,39 +526,32 @@ class VariableExpression(Expression, Variable): Expression that is a polynomial variable, i.e. an expression that cannot be reduced further. """ - @property - @override - def underlying(self): - return None - - @override - def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: - # Since there is no underlying this class directly does the job of - # creating a polymatrix - state, indices = state.index(self) - p = PolyMatrixDict() + def __hash__(self): + return self.underlying.__hash__() - rows, cols = self.shape - for (row, col), index in zip(itertools.product(range(rows), range(cols)), indices): - p[row, col] = PolyDict({ - # Create monomial with variable to the first power - # with coefficient of one - MonomialIndex((VariableIndex(index, power=1),)): 1. - }) + @override + @property + def name(self): + return self.underlyng.name - return state, init_poly_matrix(p, self.shape) + @override + @property + def shape(self): + return self.underlyng.shape @dataclassabc(frozen=True) class VariableExpressionImpl(VariableExpression): - name: str - shape: tuple[int, int] + underlying: ExpressionBaseMixin + + def __repr__(self) -> str: + return self.underlying.__repr__() + + def copy(self, underlying: ExpressionBaseMixin) -> Expression: + return init_expression(underlying) - @override - def copy(self, underlying: ExpressionBaseMixin) -> "Expression": - return self +def init_variable_expression(underlying: VariableMixin): + return VariableExpressionImpl(underlying) -def init_variable_expression(name: str, shape: tuple[int, int] = (1, 1)): - return VariableExpressionImpl(name, shape) diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index 6fdb672..a1303a3 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -35,8 +35,9 @@ def from_( def from_names(names: str, shape: tuple[int, int] = (1,1)) -> tuple[VariableExpression] | VariableExpression: """ Construct one or multiple variables from comma separated a list of names. """ - variables = tuple(init_variable_expression(name.strip(), shape) - for name in names.split(",")) + variables = tuple(init_variable_expression( + underlying=polymatrix.expression.init.init_variable_expr(name.strip(), shape)) + for name in names.split(",")) if len(variables) == 1: return variables[0] diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 8551fc5..119d8c7 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -1,5 +1,6 @@ import numpy.typing import sympy +from typing_extensions import override import dataclassabc from polymatrix.expression.mixins.integrateexprmixin import IntegrateExprMixin @@ -65,6 +66,7 @@ from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ( ) from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin +from polymatrix.expression.mixins.variablemixin import VariableMixin from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin from polymatrix.expression.mixins.tosortedvariablesmixin import ( ToSortedVariablesExprMixin, @@ -375,5 +377,11 @@ class TruncateExprImpl(TruncateExprMixin): @dataclassabc.dataclassabc(frozen=True) +class VariableImpl(VariableMixin): + name: str + shape: tuple[int, int] + + +@dataclassabc.dataclassabc(frozen=True) class VStackExprImpl(VStackExprMixin): underlying: tuple diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 0e6faf9..131dd82 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -15,6 +15,7 @@ from polymatrix.utils.getstacklines import get_stack_lines from polymatrix.expression.utils.formatsubstitutions import format_substitutions from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.impl import AdditionExprImpl +from polymatrix.expression.expression import VariableExpression def init_addition_expr( @@ -174,6 +175,9 @@ def init_from_sympy_expr(data: sympy.Expr | sympy.Matrix | tuple[tuple[sympy.Exp def init_from_expr_or_none( data: FromDataTypes, ) -> ExpressionBaseMixin | None: + if isinstance(data, VariableExpression): + return data + if isinstance(data, str): return init_parametrize_expr( underlying=init_from_expr_or_none(1), # FIXME: typing @@ -529,3 +533,7 @@ def init_truncate_expr( degrees=degrees, inverse=inverse, ) + + +def init_variable_expr(name: str, shape: tuple[int, int]): + return polymatrix.expression.impl.VariableImpl(name, shape) diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py new file mode 100644 index 0000000..01968e4 --- /dev/null +++ b/polymatrix/expression/mixins/variablemixin.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import typing +import itertools +from typing_extensions import override + +if typing.TYPE_CHECKING: + from polymatrix.expressionstate.mixins import ExpressionStateMixin + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex +from polymatrix.variable.abc import Variable + + +class VariableMixin(ExpressionBaseMixin, Variable): + """ Underlying object for VariableExpression """ + + @override + def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: + state, indices = state.index(self) + p = PolyMatrixDict() + + rows, cols = self.shape + for (row, col), index in zip(itertools.product(range(rows), range(cols)), indices): + p[row, col] = PolyDict({ + # Create monomial with variable to the first power + # with coefficient of one + MonomialIndex((VariableIndex(index, power=1),)): 1. + }) + + return state, init_poly_matrix(p, self.shape) + |