summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/expression.py51
-rw-r--r--polymatrix/expression/from_.py5
-rw-r--r--polymatrix/expression/impl.py8
-rw-r--r--polymatrix/expression/init.py8
-rw-r--r--polymatrix/expression/mixins/variablemixin.py34
5 files changed, 74 insertions, 32 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index 9d910db..a7e2025 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -2,7 +2,6 @@ from __future__ import annotations
import dataclasses
import typing
-import itertools
import numpy as np
from abc import ABC, abstractmethod
@@ -18,9 +17,8 @@ import polymatrix.expression.init
from polymatrix.utils.getstacklines import get_stack_lines
from polymatrix.polymatrix.abc import PolyMatrix
-from polymatrix.polymatrix.init import init_poly_matrix
-from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.variablemixin import VariableMixin
from polymatrix.expression.op import (
diff,
integrate,
@@ -508,7 +506,7 @@ class ExpressionImpl(Expression):
def __repr__(self) -> str:
return self.underlying.__repr__()
- def copy(self, underlying: ExpressionBaseMixin) -> "Expression":
+ def copy(self, underlying: ExpressionBaseMixin) -> Expression:
return dataclasses.replace(
self,
underlying=underlying,
@@ -528,39 +526,32 @@ class VariableExpression(Expression, Variable):
Expression that is a polynomial variable, i.e. an expression that cannot be
reduced further.
"""
- @property
- @override
- def underlying(self):
- return None
-
- @override
- def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
- # Since there is no underlying this class directly does the job of
- # creating a polymatrix
- state, indices = state.index(self)
- p = PolyMatrixDict()
+ def __hash__(self):
+ return self.underlying.__hash__()
- rows, cols = self.shape
- for (row, col), index in zip(itertools.product(range(rows), range(cols)), indices):
- p[row, col] = PolyDict({
- # Create monomial with variable to the first power
- # with coefficient of one
- MonomialIndex((VariableIndex(index, power=1),)): 1.
- })
+ @override
+ @property
+ def name(self):
+ return self.underlyng.name
- return state, init_poly_matrix(p, self.shape)
+ @override
+ @property
+ def shape(self):
+ return self.underlyng.shape
@dataclassabc(frozen=True)
class VariableExpressionImpl(VariableExpression):
- name: str
- shape: tuple[int, int]
+ underlying: ExpressionBaseMixin
+
+ def __repr__(self) -> str:
+ return self.underlying.__repr__()
+
+ def copy(self, underlying: ExpressionBaseMixin) -> Expression:
+ return init_expression(underlying)
- @override
- def copy(self, underlying: ExpressionBaseMixin) -> "Expression":
- return self
+def init_variable_expression(underlying: VariableMixin):
+ return VariableExpressionImpl(underlying)
-def init_variable_expression(name: str, shape: tuple[int, int] = (1, 1)):
- return VariableExpressionImpl(name, shape)
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index 6fdb672..a1303a3 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -35,8 +35,9 @@ def from_(
def from_names(names: str, shape: tuple[int, int] = (1,1)) -> tuple[VariableExpression] | VariableExpression:
""" Construct one or multiple variables from comma separated a list of names. """
- variables = tuple(init_variable_expression(name.strip(), shape)
- for name in names.split(","))
+ variables = tuple(init_variable_expression(
+ underlying=polymatrix.expression.init.init_variable_expr(name.strip(), shape))
+ for name in names.split(","))
if len(variables) == 1:
return variables[0]
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 8551fc5..119d8c7 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -1,5 +1,6 @@
import numpy.typing
import sympy
+from typing_extensions import override
import dataclassabc
from polymatrix.expression.mixins.integrateexprmixin import IntegrateExprMixin
@@ -65,6 +66,7 @@ from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import (
)
from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin
from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin
+from polymatrix.expression.mixins.variablemixin import VariableMixin
from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin
from polymatrix.expression.mixins.tosortedvariablesmixin import (
ToSortedVariablesExprMixin,
@@ -375,5 +377,11 @@ class TruncateExprImpl(TruncateExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class VariableImpl(VariableMixin):
+ name: str
+ shape: tuple[int, int]
+
+
+@dataclassabc.dataclassabc(frozen=True)
class VStackExprImpl(VStackExprMixin):
underlying: tuple
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 0e6faf9..131dd82 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -15,6 +15,7 @@ from polymatrix.utils.getstacklines import get_stack_lines
from polymatrix.expression.utils.formatsubstitutions import format_substitutions
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.impl import AdditionExprImpl
+from polymatrix.expression.expression import VariableExpression
def init_addition_expr(
@@ -174,6 +175,9 @@ def init_from_sympy_expr(data: sympy.Expr | sympy.Matrix | tuple[tuple[sympy.Exp
def init_from_expr_or_none(
data: FromDataTypes,
) -> ExpressionBaseMixin | None:
+ if isinstance(data, VariableExpression):
+ return data
+
if isinstance(data, str):
return init_parametrize_expr(
underlying=init_from_expr_or_none(1), # FIXME: typing
@@ -529,3 +533,7 @@ def init_truncate_expr(
degrees=degrees,
inverse=inverse,
)
+
+
+def init_variable_expr(name: str, shape: tuple[int, int]):
+ return polymatrix.expression.impl.VariableImpl(name, shape)
diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py
new file mode 100644
index 0000000..01968e4
--- /dev/null
+++ b/polymatrix/expression/mixins/variablemixin.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import typing
+import itertools
+from typing_extensions import override
+
+if typing.TYPE_CHECKING:
+ from polymatrix.expressionstate.mixins import ExpressionStateMixin
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex
+from polymatrix.variable.abc import Variable
+
+
+class VariableMixin(ExpressionBaseMixin, Variable):
+ """ Underlying object for VariableExpression """
+
+ @override
+ def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+ state, indices = state.index(self)
+ p = PolyMatrixDict()
+
+ rows, cols = self.shape
+ for (row, col), index in zip(itertools.product(range(rows), range(cols)), indices):
+ p[row, col] = PolyDict({
+ # Create monomial with variable to the first power
+ # with coefficient of one
+ MonomialIndex((VariableIndex(index, power=1),)): 1.
+ })
+
+ return state, init_poly_matrix(p, self.shape)
+