summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/__init__.py4
-rw-r--r--polymatrix/expression/__init__.py104
-rw-r--r--polymatrix/expression/expression.py12
-rw-r--r--polymatrix/expression/from_.py125
-rw-r--r--polymatrix/expression/impl.py28
-rw-r--r--polymatrix/expression/init.py13
-rw-r--r--polymatrix/expression/mixins/nsexprmixin.py60
-rw-r--r--polymatrix/expression/mixins/vstackexprmixin.py73
-rw-r--r--polymatrix/expression/typing.py15
-rw-r--r--polymatrix/polymatrix/init.py1
-rw-r--r--polymatrix/statemonad.py1
11 files changed, 260 insertions, 176 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index cd1a4e1..6e27fae 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -7,7 +7,7 @@ from polymatrix.expression.from_ import (
from_ as internal_from,
from_names as internal_from_names,
from_name as internal_from_name,
- from_statemonad as internal_from_statemonad,
+ from_state_monad as internal_from_state_monad,
)
from polymatrix.expression import (
@@ -53,4 +53,4 @@ to_affine = to_affine_expression
from_ = internal_from
from_names = internal_from_names
from_name = internal_from_name
-from_statemonad = internal_from_statemonad
+from_state_monad = internal_from_state_monad
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index 90cf425..87ee99d 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -1,53 +1,78 @@
-import polymatrix.expression.from_
-import polymatrix.expression.impl
-
from collections.abc import Iterable
+from typing import Callable
+from functools import wraps
+
+import polymatrix.expression.from_ as from_
+import polymatrix.expression.impl
from polymatrix.utils.getstacklines import get_stack_lines
from polymatrix.expression.expression import init_expression, Expression
-# NP: Why are these functions here?
+from polymatrix.expression.init import (
+ init_block_diag_expr,
+ init_concatenate_expr,
+ init_lower_triangular_expr,
+ init_ns_expr,
+)
-# FIXME: move in the correct file
-def v_stack(
- expressions: Iterable[Expression],
-) -> Expression:
- def gen_underlying():
- for expr in expressions:
- if isinstance(expr, Expression):
- yield expr.underlying
- else:
- yield polymatrix.expression.from_.from_(expr).underlying
+def convert_args_to_expression(fn: Callable) -> Callable:
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ wrapped_args = (
+ from_.from_any(arg)
+ for arg in args
+ )
- return init_expression(
- underlying=polymatrix.expression.impl.VStackExprImpl(
- underlying=tuple(gen_underlying()),
- ),
- )
+ wrapped_kwargs = {
+ kw: from_.from_any(arg)
+ for kw, arg in kwargs.items()
+ }
+ return fn(*wrapped_args, **wrapped_kwargs)
+ return wrapper
-# FIXME: move in the correct file
-def h_stack(
- expressions: Iterable[Expression],
-) -> Expression:
- return v_stack((expr.T for expr in expressions)).T
+@convert_args_to_expression
+def ones(shape: Expression):
+ """ Make a matrix filled with ones """
+ return init_expression(init_ns_expr(n=1., shape=shape.underlying))
-# FIXME: move in the correct file
-def block_diag(
- expressions: tuple[Expression],
-) -> Expression:
- return init_expression(
- underlying=polymatrix.expression.impl.BlockDiagExprImpl(
- underlying=expressions,
- )
- )
+
+@convert_args_to_expression
+def zeros(shape: Expression):
+ """ Make a matrix filled with zeros """
+ return init_expression(init_ns_expr(n=1., shape=shape.underlying))
+
+
+def v_stack(expressions: Iterable) -> Expression:
+ """ Vertically stack expressions """
+ # arrange blocks vertically and concatanete
+ blocks = tuple((from_.from_any(e).underlying,) for e in expressions)
+ return init_expression(init_concatenate_expr(blocks))
+
+
+def h_stack(expressions: Iterable) -> Expression:
+ """ Horizontally stack expressions """
+ # arrange horizontally
+ blocks = (tuple(from_.from_any(e).underlying for e in expressions),)
+ return init_expression(init_concatenate_expr(blocks))
+
+
+def concatenate(expressions: Iterable[Iterable]):
+ """ Concatenate arrays (more general version of vstack and hstack) """
+ blocks = tuple(tuple(from_.from_any(expr).underlying for expr in row) for row in expressions)
+ return init_expression(init_concatenate_expr(blocks))
+
+
+@convert_args_to_expression
+def block_diag(expressions: tuple[Expression]) -> Expression:
+ """ Create a block diagonal matrix. """
+ return init_expression(init_block_diag_expr(expressions))
-# FIXME: move in the correct file
def product(
expressions: Iterable[Expression],
- degrees: tuple[int, ...] = None,
+ degrees: tuple[int, ...] | None = None,
):
return init_expression(
underlying=polymatrix.expression.impl.ProductExprImpl(
@@ -58,11 +83,6 @@ def product(
)
-def concatenate(arrays: Iterable[Iterable[Expression]]):
- return init_expression(underlying=polymatrix.expression.impl.ConcatenateExprImpl(
- tuple(tuple(expr.underlying for expr in row) for row in arrays)))
-
-
+@convert_args_to_expression
def lower_triangular(vector: Expression):
- return init_expression(
- underlying=polymatrix.expression.impl.LowerTriangularExprImpl(underlying=vector))
+ return init_expression(init_lower_triangular_expr(vector.underlying))
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index db39780..49164d0 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -136,10 +136,10 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def combinations(
- self,
- degrees: tuple[int, ...] | int,
- ):
+ def combinations(self, degrees: tuple[int, ...] | int | Expression):
+ if isinstance(degrees, Expression):
+ degrees = degrees.underlying
+
return self.copy(
underlying=polymatrix.expression.init.init_combinations_expr(
expression=self.underlying,
@@ -147,14 +147,14 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def degree(self) -> "Expression":
+ def degree(self) -> Expression:
return self.copy(
underlying=degree(
underlying=self.underlying,
),
)
- def determinant(self) -> "Expression":
+ def determinant(self) -> Expression:
return self.copy(
underlying=polymatrix.expression.init.init_determinant_expr(
underlying=self.underlying,
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index 8dbc055..a66580d 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -1,40 +1,69 @@
-from typing import Iterable
-from typing_extensions import override
-from polymatrix.expression.typing import FromDataTypes
+import sympy
+import numpy
-import polymatrix.expression.init
+from numpy.typing import NDArray
+from typing import Iterable, Any
+
+import polymatrix.expression.init as init
-from polymatrix.expression.expression import init_expression, Expression, init_variable_expression, VariableExpression
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.typing import FromSupportedTypes
from polymatrix.statemonad import StateMonad
+from polymatrix.expression.expression import (
+ init_expression,
+ Expression,
+ init_variable_expression,
+ VariableExpression
+)
+
+
+def from_any(value: FromSupportedTypes) -> Expression:
+ if v := from_any_or(value, None):
+ return v
+
+ raise ValueError("Unsupported type. Cannot construct expression "
+ f"from value {value} with type {type(value)}")
+
+
+def from_any_or(value: FromSupportedTypes, value_if_not_supported: Any) -> Expression | Any:
+ if isinstance(value, VariableExpression | Expression):
+ return value
+
+ elif isinstance(value, ExpressionBaseMixin):
+ return init_expression(value)
+
+ elif isinstance(value, int | float):
+ return from_number(value)
-# NP: this function name makes no sense to me,
-def from_expr_or_none(
- data: FromDataTypes,
-) -> Expression:
- return init_expression(
- underlying=polymatrix.expression.init.init_from_expr_or_none(
- data=data,
- ),
- )
+ elif isinstance(value, StateMonad):
+ return from_state_monad(value)
+ elif isinstance(value, numpy.ndarray):
+ return from_numpy(value)
-# NP: from is not an ideal name because it is a keyword
-# NP: consider differnt name like make_from?
-def from_(
- data: FromDataTypes,
-) -> Expression:
- return init_expression(
- underlying=polymatrix.expression.init.init_from_expr(
- data=data,
- ),
- )
+ elif isinstance(value, sympy.Matrix | sympy.Expr):
+ return from_sympy(value)
+ elif isinstance(value, tuple):
+ if len(value) < 1:
+ return value_if_not_supported
-def from_statemonad(monad: StateMonad) -> Expression:
- return init_expression(polymatrix.expression.init.init_from_statemonad(monad))
+ if isinstance(value[0], tuple):
+ if isinstance(value[0], int | float):
+ return from_numbers(value)
+
+ elif isinstance(value[0], sympy.Expr):
+ return from_sympy(value)
+
+ elif isinstance(value[0], int | float):
+ return from_numbers(value)
+
+ elif isinstance(value[0], sympy.Expr):
+ return from_sympy((value,))
+
+ return value_if_not_supported
def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> Iterable[VariableExpression]:
@@ -48,4 +77,46 @@ def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -
if isinstance(shape, Expression):
shape = shape.underlying
- return init_variable_expression(polymatrix.expression.init.init_variable_expr(name, shape))
+ return init_variable_expression(init.init_variable_expr(name, shape))
+
+
+def from_number(num: int | float) -> Expression:
+ """ Construct an expression from a number. """
+ return init_expression(init.init_from_numbers_expr(((num,),)))
+
+
+def from_numbers(nums: Iterable[int | float] | Iterable[Iterable[int | float]]) -> Expression:
+ """ Construct vector or matrix from numbers. """
+ numbers = tuple(nums)
+
+ # Row vector
+ if isinstance(numbers[0], int | float):
+ return init_expression(init.init_from_numbers_expr((numbers,)))
+
+ # Matrix
+ numbers = tuple(tuple(n) for row in numbers for n in row)
+ return init_expression(init.init_from_numbers_expr(numbers))
+
+
+def from_numpy(array: NDArray) -> Expression:
+ """ Convert a Numpy Array into an Expression. """
+ return init_expression(init.init_from_numpy_expr(array))
+
+
+def from_state_monad(monad: StateMonad) -> Expression:
+ return init_expression(init.init_from_statemonad(monad))
+
+
+def from_sympy(expr: sympy.Matrix | sympy.Expr | tuple[tuple[sympy.Expr, ...], ...]) -> Expression:
+ return init_expression(init.init_from_sympy_expr(expr))
+
+
+# -- Old API --
+
+
+def from_expr_or_none(data: FromSupportedTypes) -> Expression | None:
+ return from_any_or(data, None)
+
+
+def from_(*args, **kwargs) -> Expression:
+ return from_any(*args, **kwargs)
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 9451ba6..ad5b9b8 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -35,6 +35,7 @@ from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomial
from polymatrix.expression.mixins.lowertriangularexprmixin import LowerTriangularExprMixin
from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin
from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin
+from polymatrix.expression.mixins.nsexprmixin import NsExprMixin
from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin
from polymatrix.expression.mixins.powerexprmixin import PowerExprMixin
from polymatrix.expression.mixins.productexprmixin import ProductExprMixin
@@ -55,7 +56,6 @@ from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ToSymmetricM
from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin
from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin
from polymatrix.expression.mixins.variablemixin import VariableMixin
-from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin
from polymatrix.expression.mixins.fromtermsexprmixin import (
FromPolynomialDataExprMixin,
@@ -102,6 +102,13 @@ class CombinationsExprImpl(CombinationsExprMixin):
class ConcatenateExprImpl(ConcatenateExprMixin):
blocks: tuple[tuple[ExpressionBaseMixin, ...], ...]
+ def __str__(self):
+ blocks = ", ".join(
+ "(" + ", ".join(str(b) for b in row) + ")"
+ for row in self.blocks)
+
+ return f"cat({blocks})"
+
@dataclassabc.dataclassabc(frozen=True)
class DerivativeExprImpl(DerivativeExprMixin):
@@ -289,6 +296,12 @@ class MaxExprImpl(MaxExprMixin):
underlying: ExpressionBaseMixin
+@dataclassabc.dataclassabc(frozen=True)
+class NsExprImpl(NsExprMixin):
+ n: int | float | ExpressionBaseMixin
+ shape: tuple[int, int] | ExpressionBaseMixin
+
+
@dataclassabc.dataclassabc(frozen=True, repr=False)
class ParametrizeExprImpl(ParametrizeExprMixin):
underlying: ExpressionBaseMixin
@@ -375,7 +388,7 @@ class ShapeExprImpl(ShapeExprMixin):
@dataclassabc.dataclassabc(frozen=True)
class SliceExprImpl(SliceExprMixin):
underlying: ExpressionBaseMixin
- slice: tuple # See SlicePolyMatrixMixin for details of this tuple
+ slice: tuple # See SlicePolyMatrixMixin
@dataclassabc.dataclassabc(frozen=True)
@@ -437,14 +450,3 @@ class VariableImpl(VariableMixin):
def __str__(self):
return self.name
-
-
-@dataclassabc.dataclassabc(frozen=True)
-class VStackExprImpl(VStackExprMixin):
- underlying: tuple[ExpressionBaseMixin, ...]
-
- def __str__(self):
- inner = ", ".join(map(str, self.underlying))
- return f"v_stack({inner})"
-
-
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index c109156..b51b92c 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -3,7 +3,7 @@ import numpy as np
import numpy.typing as npt
import sympy
-from polymatrix.expression.typing import FromDataTypes
+from polymatrix.expression.typing import FromSupportedTypes
import polymatrix.expression.impl
@@ -48,7 +48,7 @@ def init_cache_expr(
def init_combinations_expr(
expression: ExpressionBaseMixin,
- degrees: tuple[int, ...] | int,
+ degrees: tuple[int, ...] | int | ExpressionBaseMixin,
):
if isinstance(degrees, int):
degrees = (degrees,)
@@ -157,10 +157,11 @@ def init_from_statemonad(monad: StateMonad):
return polymatrix.expression.impl.FromStateMonadImpl(monad=monad)
+# TODO: remove this function, replaced by from_any
# NP: this function should be split up into smaller functions, one for each "from" type
# NP: and each "from" should be documented, explaining how it is interpreted.
def init_from_expr_or_none(
- data: FromDataTypes,
+ data: FromSupportedTypes,
) -> ExpressionBaseMixin | None:
if isinstance(data, VariableExpression):
return data
@@ -208,7 +209,7 @@ def init_from_expr_or_none(
return None
-def init_from_expr(data: FromDataTypes):
+def init_from_expr(data: FromSupportedTypes):
expr = init_from_expr_or_none(data)
if expr is None:
@@ -295,6 +296,10 @@ def init_max_expr(
)
+def init_ns_expr(n: int | float | ExpressionBaseMixin, shape: tuple[int, int] | ExpressionBaseMixin):
+ return polymatrix.expression.impl.NsExprImpl(n, shape)
+
+
def init_parametrize_expr(
underlying: ExpressionBaseMixin,
name: str = None, # FIXME: typing
diff --git a/polymatrix/expression/mixins/nsexprmixin.py b/polymatrix/expression/mixins/nsexprmixin.py
new file mode 100644
index 0000000..f97d724
--- /dev/null
+++ b/polymatrix/expression/mixins/nsexprmixin.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from math import sqrt, isclose
+from typing_extensions import override
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate import ExpressionState
+from polymatrix.polymatrix.index import PolyDict, MonomialIndex
+from polymatrix.polymatrix.init import init_broadcast_poly_matrix
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
+
+class NsExprMixin(ExpressionBaseMixin):
+ """
+ Make a matrix or vector that is filled with N's. This is used to provide
+ functions line ones() or zeros(), i.e. N = 1 and N = 0 respectively. N may
+ not only be a number but also any scalar expression.
+ """
+
+ @property
+ @abstractmethod
+ def n(self) -> int | float | ExpressionBaseMixin:
+ """ Value that is used to fill the matrix """
+
+ @property
+ @abstractmethod
+ def shape(self) -> tuple[int, int] | ExpressionBaseMixin:
+ """
+ Shape of the ones. If it is an expression it must evaluate to a 2d
+ column vector of integers.
+ """
+
+ @override
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
+ if isinstance(self.n, ExpressionBaseMixin):
+ state, pm = self.n.apply(state)
+ n = pm.scalar()
+
+ else:
+ n = PolyDict({ MonomialIndex.constant(): 1. })
+
+ if isinstance(self.shape, ExpressionBaseMixin):
+ state, pm = self.shape.apply(state)
+ if pm.shape != (2, 1):
+ raise ValueError("Shape must evaluate to a 2d column vector "
+ f"but it has shape {pm.shape}")
+
+ # FIXME: should check that they are actually integers
+ nrows = int(pm.at(0, 0).constant())
+ ncols = int(pm.at(1, 0).constant())
+
+ else:
+ nrows, ncols = self.shape
+
+ return state, init_broadcast_poly_matrix(n, shape=(nrows, ncols))
+
+
+
+
diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py
deleted file mode 100644
index 249bb26..0000000
--- a/polymatrix/expression/mixins/vstackexprmixin.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import abc
-import itertools
-import dataclassabc
-from polymatrix.polymatrix.mixins import PolyMatrixMixin
-from polymatrix.polymatrix.index import PolyDict
-
-from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.polymatrix.abc import PolyMatrix
-from polymatrix.expressionstate import ExpressionState
-
-
-class VStackExprMixin(ExpressionBaseMixin):
- """
- Vertical stacking of the underlying polynomial matrices
-
- [[1, 2]], [[3, 4]] -> [[1, 2], [3, 4]]
- """
-
- @property
- @abc.abstractmethod
- def underlying(self) -> tuple[ExpressionBaseMixin, ...]: ...
-
- # overwrites the abstract method of `ExpressionBaseMixin`
- def apply(
- self,
- state: ExpressionState,
- ) -> tuple[ExpressionState, PolyMatrix]:
- all_underlying = []
- for expr in self.underlying:
- state, polymatrix = expr.apply(state=state)
- all_underlying.append(polymatrix)
-
- for underlying in all_underlying:
- assert (
- underlying.shape[1] == all_underlying[0].shape[1]
- ), f"{underlying.shape[1]} not equal {all_underlying[0].shape[1]}"
-
- # FIXME: move to polymatrix module
- @dataclassabc.dataclassabc(frozen=True)
- class VStackPolyMatrix(PolyMatrixMixin):
- all_underlying: tuple[PolyMatrixMixin]
- underlying_row_range: tuple[tuple[int, int], ...]
- shape: tuple[int, int]
-
- def at(self, row: int, col: int) -> PolyDict:
- for polymatrix, (row_start, row_end) in zip(
- self.all_underlying, self.underlying_row_range
- ):
- if row_start <= row < row_end:
- return polymatrix.at(
- row=row - row_start,
- col=col,
- )
-
- raise Exception(f"row {row} is out of bounds")
-
- underlying_row_range = tuple(
- itertools.pairwise(
- itertools.accumulate(
- (expr.shape[0] for expr in all_underlying), initial=0
- )
- )
- )
-
- n_row = sum(polymatrix.shape[0] for polymatrix in all_underlying)
-
- polymatrix = VStackPolyMatrix(
- all_underlying=all_underlying,
- shape=(n_row, all_underlying[0].shape[1]),
- underlying_row_range=underlying_row_range,
- )
-
- return state, polymatrix
diff --git a/polymatrix/expression/typing.py b/polymatrix/expression/typing.py
index af762a0..87ade26 100644
--- a/polymatrix/expression/typing.py
+++ b/polymatrix/expression/typing.py
@@ -1,18 +1,17 @@
from __future__ import annotations
+from numpy.typing import NDArray
+import sympy
+
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.statemonad import StateMonad
-import numpy.typing as npt
-import sympy
-
-FromDataTypes = (
+FromSupportedTypes = (
str
- | npt.NDArray
- | sympy.Matrix
- | sympy.Expr
- | tuple
+ | NDArray
+ | sympy.Matrix | sympy.Expr
+ | tuple[tuple[...]]
| ExpressionBaseMixin
| StateMonad
)
diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py
index b8af9f7..ac6d199 100644
--- a/polymatrix/polymatrix/init.py
+++ b/polymatrix/polymatrix/init.py
@@ -19,7 +19,6 @@ if TYPE_CHECKING:
from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin, PolyMatrixMixin, SlicePolyMatrixMixin
-# FIXME: use polymatrix.typing
def init_poly_matrix(
data: PolyMatrixDict,
shape: tuple[int, int],
diff --git a/polymatrix/statemonad.py b/polymatrix/statemonad.py
index bfd63de..f3ddfd6 100644
--- a/polymatrix/statemonad.py
+++ b/polymatrix/statemonad.py
@@ -25,6 +25,7 @@ class StateMonad(Generic[State, U]):
# was passed to the statemonads that are applied to expressions. For
# example in to_sympy, you want so see what expression is converted to
# sympy.
+
# arguments that were given to the function apply_func.
# this field is optional
arguments: U | None