summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/__init__.py61
-rw-r--r--polymatrix/expression/from_.py11
-rw-r--r--polymatrix/expression/typing.py1
3 files changed, 21 insertions, 52 deletions
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index ed3ba2e..c7c365d 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -7,6 +7,7 @@ import polymatrix.expression.impl
from polymatrix.utils.getstacklines import get_stack_lines
from polymatrix.expression.expression import init_expression, Expression
+from polymatrix.expression.from_ import from_any
from polymatrix.expression.mixins.namedexprmixin import NamedExprMixin
from polymatrix.expression.init import (
@@ -41,30 +42,14 @@ def ones(shape: int | tuple[int | Expression, int | Expression] | Expression) ->
expr = init_ns_expr(n=1, shape=shape.underlying)
elif isinstance(shape, int):
- expr = init_ns_expr(n=1, shape=(shape, 1))
+ expr = init_ns_expr(n=1, shape=v_stack((shape, 1)))
elif isinstance(shape, tuple):
- if isinstance(shape[0], Expression):
- nrows = shape[0].underlying
+ nrows, ncols = shape
+ if (not isinstance(nrows, int)) or (not isinstance(ncols, int)):
+ shape = from_any(((nrows,), (ncols,))).underlying
- elif isinstance(shape[0], int):
- nrows = shape[0]
-
- else:
- raise TypeError("Number of rows must be an integer or an expression "
- "that evaluates to an integer")
-
- if isinstance(shape[1], Expression):
- ncols = shape[1].underlying
-
- elif isinstance(shape[1], int):
- ncols = shape[1]
-
- else:
- raise TypeError("Number of columns must be an integer or an expression "
- "that evaluates to an integer")
-
- expr = init_ns_expr(n=1, shape=(nrows, ncols))
+ expr = init_ns_expr(n=1, shape=shape)
else:
# TODO: error messages
@@ -98,30 +83,14 @@ def zeros(shape: int | tuple[int | Expression, int | Expression] | Expression):
expr = init_ns_expr(n=0, shape=shape.underlying)
elif isinstance(shape, int):
- expr = init_ns_expr(n=0, shape=(shape, 1))
+ expr = init_ns_expr(n=0, shape=v_stack((shape, 1)))
elif isinstance(shape, tuple):
- if isinstance(shape[0], Expression):
- nrows = shape[0].underlying
-
- elif isinstance(shape[0], int):
- nrows = shape[0]
+ nrows, ncols = shape
+ if (not isinstance(nrows, int)) or (not isinstance(ncols, int)):
+ shape = from_any(((nrows,), (ncols,))).underlying
- else:
- raise TypeError("Number of rows must be an integer or an expression "
- "that evaluates to an integer")
-
- if isinstance(shape[1], Expression):
- ncols = shape[1].underlying
-
- elif isinstance(shape[1], int):
- ncols = shape[1]
-
- else:
- raise TypeError("Number of columns must be an integer or an expression "
- "that evaluates to an integer")
-
- expr = init_ns_expr(n=0, shape=(nrows, ncols))
+ expr = init_ns_expr(n=0, shape=shape)
else:
# TODO: error messages
@@ -134,14 +103,18 @@ def v_stack(expressions: Iterable) -> Expression:
""" Vertically stack expressions """
# arrange blocks vertically and concatanete
blocks = tuple((from_.from_any(e).underlying,) for e in expressions)
- return init_expression(init_concatenate_expr(blocks))
+ u = init_expression(init_concatenate_expr(blocks))
+ names = ", ".join(str(b[0]) for b in blocks)
+ return give_name(u, f"vstack({names})")
def h_stack(expressions: Iterable) -> Expression:
""" Horizontally stack expressions """
# arrange horizontally
blocks = (tuple(from_.from_any(e).underlying for e in expressions),)
- return init_expression(init_concatenate_expr(blocks))
+ u = init_expression(init_concatenate_expr(blocks))
+ names = ", ".join(str(b) for b in blocks[0])
+ return give_name(u, f"hstack({names})")
def concatenate(expressions: Iterable[Iterable]):
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index 03091c7..3be44e5 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -83,7 +83,7 @@ def from_any_or(value: FromSupportedTypes, value_if_not_supported: Any) -> Expre
# row vector tuple[...]
- elif all(not isinstance(row, tuple) for v in value):
+ elif all(not isinstance(row, tuple) for row in value):
wrapped_rows: list[Expression] = []
for row in value:
wrapped = from_any_or(row, None)
@@ -119,13 +119,8 @@ def from_name(
elif isinstance(shape, tuple):
nrows, ncols = shape
- if isinstance(nrows, Expression):
- nrows = nrows.underlying
-
- if isinstance(ncols, Expression):
- ncols = ncols.underlying
-
- shape = (nrows, ncols)
+ if (not isinstance(nrows, int)) or (not isinstance(ncols, int)):
+ shape = from_any(((nrows,), (ncols,))).underlying
return init_variable_expression(init.init_variable_expr(Symbol(name), shape))
diff --git a/polymatrix/expression/typing.py b/polymatrix/expression/typing.py
index 87ade26..23d7f66 100644
--- a/polymatrix/expression/typing.py
+++ b/polymatrix/expression/typing.py
@@ -11,6 +11,7 @@ FromSupportedTypes = (
str
| NDArray
| sympy.Matrix | sympy.Expr
+ | tuple[...]
| tuple[tuple[...]]
| ExpressionBaseMixin
| StateMonad