summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-04-24 12:26:13 +0200
committerNao Pross <np@0hm.ch>2024-04-24 12:26:13 +0200
commit821945807fa76c8edb0b36ef4ee741e872f0c0e2 (patch)
tree2b1ae4db76bf5405a239b46b44a841568045911c
parentAdd function to construct variable from_names (diff)
downloadpolymatrix-821945807fa76c8edb0b36ef4ee741e872f0c0e2.tar.gz
polymatrix-821945807fa76c8edb0b36ef4ee741e872f0c0e2.zip
Fix expression.init.init_variable_expr
The impl class has to be only in the underlying to have the behaviour of expressions
-rw-r--r--polymatrix/expression/expression.py8
-rw-r--r--polymatrix/expression/from_.py2
-rw-r--r--polymatrix/expression/init.py3
3 files changed, 7 insertions, 6 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index 8d53354..f52a0ea 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -495,8 +495,8 @@ class Expression(
)
-# NP: why is this impl class here?
-# FIXME: move to impl.py
+# This is here and not in impl.py because of circular imports
+# FIXME: this is not ideal
@dataclassabc.dataclassabc(frozen=True, repr=False)
class ExpressionImpl(Expression):
underlying: ExpressionBaseMixin
@@ -512,8 +512,8 @@ class ExpressionImpl(Expression):
-# NP: why is this here and not in init.py?
-# FIXME: move to init.py
+# This is here and not in init.py because of circular imports
+# FIXME: this is not ideal
def init_expression(
underlying: ExpressionBaseMixin,
):
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index c34e75d..c833e8b 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -37,7 +37,7 @@ def from_(
)
-def from_names(names: str, shape: tuple[int, int] = (1,1)) -> VariableMixin | tuple[VariableMixin]:
+def from_names(names: str, shape: tuple[int, int] = (1,1)) -> tuple[VariableMixin] | VariableMixin:
""" Construct one or multiple variables from comma separated a list of names. """
variables = tuple(init_variable_expr(name.strip(), shape)
for name in names.split(","))
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 5acb40d..4959e11 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -527,4 +527,5 @@ def init_truncate_expr(
def init_variable_expr(name: str, shape: tuple[int, int] = (1, 1)):
- return polymatrix.expression.impl.VariableImpl(name, shape)
+ return polymatrix.expression.expression.init_expression(
+ underlying=polymatrix.expression.impl.VariableImpl(name, shape))