diff options
author | Nao Pross <np@0hm.ch> | 2024-05-25 15:38:36 +0200 |
---|---|---|
committer | Nao Pross <np@0hm.ch> | 2024-05-25 15:40:34 +0200 |
commit | 2581dc00949cbc0a8088e2e5bf7cfdfb698ff6f1 (patch) | |
tree | 7201e99693d551b6627814577e5f93eb0f58e94a | |
parent | Create shorthand PolyMatrix.scalar() for .at(0,0) (diff) | |
download | polymatrix-2581dc00949cbc0a8088e2e5bf7cfdfb698ff6f1.tar.gz polymatrix-2581dc00949cbc0a8088e2e5bf7cfdfb698ff6f1.zip |
Create zeros(), ones(), rewrite VStack with Concatenate and clean up from_
-rw-r--r-- | polymatrix/__init__.py | 4 | ||||
-rw-r--r-- | polymatrix/expression/__init__.py | 104 | ||||
-rw-r--r-- | polymatrix/expression/expression.py | 12 | ||||
-rw-r--r-- | polymatrix/expression/from_.py | 125 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 28 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 13 | ||||
-rw-r--r-- | polymatrix/expression/mixins/nsexprmixin.py | 60 | ||||
-rw-r--r-- | polymatrix/expression/mixins/vstackexprmixin.py | 73 | ||||
-rw-r--r-- | polymatrix/expression/typing.py | 15 | ||||
-rw-r--r-- | polymatrix/polymatrix/init.py | 1 | ||||
-rw-r--r-- | polymatrix/statemonad.py | 1 |
11 files changed, 260 insertions, 176 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index cd1a4e1..6e27fae 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -7,7 +7,7 @@ from polymatrix.expression.from_ import ( from_ as internal_from, from_names as internal_from_names, from_name as internal_from_name, - from_statemonad as internal_from_statemonad, + from_state_monad as internal_from_state_monad, ) from polymatrix.expression import ( @@ -53,4 +53,4 @@ to_affine = to_affine_expression from_ = internal_from from_names = internal_from_names from_name = internal_from_name -from_statemonad = internal_from_statemonad +from_state_monad = internal_from_state_monad diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index 90cf425..87ee99d 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -1,53 +1,78 @@ -import polymatrix.expression.from_ -import polymatrix.expression.impl - from collections.abc import Iterable +from typing import Callable +from functools import wraps + +import polymatrix.expression.from_ as from_ +import polymatrix.expression.impl from polymatrix.utils.getstacklines import get_stack_lines from polymatrix.expression.expression import init_expression, Expression -# NP: Why are these functions here? +from polymatrix.expression.init import ( + init_block_diag_expr, + init_concatenate_expr, + init_lower_triangular_expr, + init_ns_expr, +) -# FIXME: move in the correct file -def v_stack( - expressions: Iterable[Expression], -) -> Expression: - def gen_underlying(): - for expr in expressions: - if isinstance(expr, Expression): - yield expr.underlying - else: - yield polymatrix.expression.from_.from_(expr).underlying +def convert_args_to_expression(fn: Callable) -> Callable: + @wraps(fn) + def wrapper(*args, **kwargs): + wrapped_args = ( + from_.from_any(arg) + for arg in args + ) - return init_expression( - underlying=polymatrix.expression.impl.VStackExprImpl( - underlying=tuple(gen_underlying()), - ), - ) + wrapped_kwargs = { + kw: from_.from_any(arg) + for kw, arg in kwargs.items() + } + return fn(*wrapped_args, **wrapped_kwargs) + return wrapper -# FIXME: move in the correct file -def h_stack( - expressions: Iterable[Expression], -) -> Expression: - return v_stack((expr.T for expr in expressions)).T +@convert_args_to_expression +def ones(shape: Expression): + """ Make a matrix filled with ones """ + return init_expression(init_ns_expr(n=1., shape=shape.underlying)) -# FIXME: move in the correct file -def block_diag( - expressions: tuple[Expression], -) -> Expression: - return init_expression( - underlying=polymatrix.expression.impl.BlockDiagExprImpl( - underlying=expressions, - ) - ) + +@convert_args_to_expression +def zeros(shape: Expression): + """ Make a matrix filled with zeros """ + return init_expression(init_ns_expr(n=1., shape=shape.underlying)) + + +def v_stack(expressions: Iterable) -> Expression: + """ Vertically stack expressions """ + # arrange blocks vertically and concatanete + blocks = tuple((from_.from_any(e).underlying,) for e in expressions) + return init_expression(init_concatenate_expr(blocks)) + + +def h_stack(expressions: Iterable) -> Expression: + """ Horizontally stack expressions """ + # arrange horizontally + blocks = (tuple(from_.from_any(e).underlying for e in expressions),) + return init_expression(init_concatenate_expr(blocks)) + + +def concatenate(expressions: Iterable[Iterable]): + """ Concatenate arrays (more general version of vstack and hstack) """ + blocks = tuple(tuple(from_.from_any(expr).underlying for expr in row) for row in expressions) + return init_expression(init_concatenate_expr(blocks)) + + +@convert_args_to_expression +def block_diag(expressions: tuple[Expression]) -> Expression: + """ Create a block diagonal matrix. """ + return init_expression(init_block_diag_expr(expressions)) -# FIXME: move in the correct file def product( expressions: Iterable[Expression], - degrees: tuple[int, ...] = None, + degrees: tuple[int, ...] | None = None, ): return init_expression( underlying=polymatrix.expression.impl.ProductExprImpl( @@ -58,11 +83,6 @@ def product( ) -def concatenate(arrays: Iterable[Iterable[Expression]]): - return init_expression(underlying=polymatrix.expression.impl.ConcatenateExprImpl( - tuple(tuple(expr.underlying for expr in row) for row in arrays))) - - +@convert_args_to_expression def lower_triangular(vector: Expression): - return init_expression( - underlying=polymatrix.expression.impl.LowerTriangularExprImpl(underlying=vector)) + return init_expression(init_lower_triangular_expr(vector.underlying)) diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index db39780..49164d0 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -136,10 +136,10 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def combinations( - self, - degrees: tuple[int, ...] | int, - ): + def combinations(self, degrees: tuple[int, ...] | int | Expression): + if isinstance(degrees, Expression): + degrees = degrees.underlying + return self.copy( underlying=polymatrix.expression.init.init_combinations_expr( expression=self.underlying, @@ -147,14 +147,14 @@ class Expression(ExpressionBaseMixin, ABC): ), ) - def degree(self) -> "Expression": + def degree(self) -> Expression: return self.copy( underlying=degree( underlying=self.underlying, ), ) - def determinant(self) -> "Expression": + def determinant(self) -> Expression: return self.copy( underlying=polymatrix.expression.init.init_determinant_expr( underlying=self.underlying, diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index 8dbc055..a66580d 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -1,40 +1,69 @@ -from typing import Iterable -from typing_extensions import override -from polymatrix.expression.typing import FromDataTypes +import sympy +import numpy -import polymatrix.expression.init +from numpy.typing import NDArray +from typing import Iterable, Any + +import polymatrix.expression.init as init -from polymatrix.expression.expression import init_expression, Expression, init_variable_expression, VariableExpression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.typing import FromSupportedTypes from polymatrix.statemonad import StateMonad +from polymatrix.expression.expression import ( + init_expression, + Expression, + init_variable_expression, + VariableExpression +) + + +def from_any(value: FromSupportedTypes) -> Expression: + if v := from_any_or(value, None): + return v + + raise ValueError("Unsupported type. Cannot construct expression " + f"from value {value} with type {type(value)}") + + +def from_any_or(value: FromSupportedTypes, value_if_not_supported: Any) -> Expression | Any: + if isinstance(value, VariableExpression | Expression): + return value + + elif isinstance(value, ExpressionBaseMixin): + return init_expression(value) + + elif isinstance(value, int | float): + return from_number(value) -# NP: this function name makes no sense to me, -def from_expr_or_none( - data: FromDataTypes, -) -> Expression: - return init_expression( - underlying=polymatrix.expression.init.init_from_expr_or_none( - data=data, - ), - ) + elif isinstance(value, StateMonad): + return from_state_monad(value) + elif isinstance(value, numpy.ndarray): + return from_numpy(value) -# NP: from is not an ideal name because it is a keyword -# NP: consider differnt name like make_from? -def from_( - data: FromDataTypes, -) -> Expression: - return init_expression( - underlying=polymatrix.expression.init.init_from_expr( - data=data, - ), - ) + elif isinstance(value, sympy.Matrix | sympy.Expr): + return from_sympy(value) + elif isinstance(value, tuple): + if len(value) < 1: + return value_if_not_supported -def from_statemonad(monad: StateMonad) -> Expression: - return init_expression(polymatrix.expression.init.init_from_statemonad(monad)) + if isinstance(value[0], tuple): + if isinstance(value[0], int | float): + return from_numbers(value) + + elif isinstance(value[0], sympy.Expr): + return from_sympy(value) + + elif isinstance(value[0], int | float): + return from_numbers(value) + + elif isinstance(value[0], sympy.Expr): + return from_sympy((value,)) + + return value_if_not_supported def from_names(names: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) -> Iterable[VariableExpression]: @@ -48,4 +77,46 @@ def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) - if isinstance(shape, Expression): shape = shape.underlying - return init_variable_expression(polymatrix.expression.init.init_variable_expr(name, shape)) + return init_variable_expression(init.init_variable_expr(name, shape)) + + +def from_number(num: int | float) -> Expression: + """ Construct an expression from a number. """ + return init_expression(init.init_from_numbers_expr(((num,),))) + + +def from_numbers(nums: Iterable[int | float] | Iterable[Iterable[int | float]]) -> Expression: + """ Construct vector or matrix from numbers. """ + numbers = tuple(nums) + + # Row vector + if isinstance(numbers[0], int | float): + return init_expression(init.init_from_numbers_expr((numbers,))) + + # Matrix + numbers = tuple(tuple(n) for row in numbers for n in row) + return init_expression(init.init_from_numbers_expr(numbers)) + + +def from_numpy(array: NDArray) -> Expression: + """ Convert a Numpy Array into an Expression. """ + return init_expression(init.init_from_numpy_expr(array)) + + +def from_state_monad(monad: StateMonad) -> Expression: + return init_expression(init.init_from_statemonad(monad)) + + +def from_sympy(expr: sympy.Matrix | sympy.Expr | tuple[tuple[sympy.Expr, ...], ...]) -> Expression: + return init_expression(init.init_from_sympy_expr(expr)) + + +# -- Old API -- + + +def from_expr_or_none(data: FromSupportedTypes) -> Expression | None: + return from_any_or(data, None) + + +def from_(*args, **kwargs) -> Expression: + return from_any(*args, **kwargs) diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 9451ba6..ad5b9b8 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -35,6 +35,7 @@ from polymatrix.expression.mixins.linearmonomialsexprmixin import LinearMonomial from polymatrix.expression.mixins.lowertriangularexprmixin import LowerTriangularExprMixin from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin +from polymatrix.expression.mixins.nsexprmixin import NsExprMixin from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin from polymatrix.expression.mixins.powerexprmixin import PowerExprMixin from polymatrix.expression.mixins.productexprmixin import ProductExprMixin @@ -55,7 +56,6 @@ from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ToSymmetricM from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin from polymatrix.expression.mixins.variablemixin import VariableMixin -from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin from polymatrix.expression.mixins.fromtermsexprmixin import ( FromPolynomialDataExprMixin, @@ -102,6 +102,13 @@ class CombinationsExprImpl(CombinationsExprMixin): class ConcatenateExprImpl(ConcatenateExprMixin): blocks: tuple[tuple[ExpressionBaseMixin, ...], ...] + def __str__(self): + blocks = ", ".join( + "(" + ", ".join(str(b) for b in row) + ")" + for row in self.blocks) + + return f"cat({blocks})" + @dataclassabc.dataclassabc(frozen=True) class DerivativeExprImpl(DerivativeExprMixin): @@ -289,6 +296,12 @@ class MaxExprImpl(MaxExprMixin): underlying: ExpressionBaseMixin +@dataclassabc.dataclassabc(frozen=True) +class NsExprImpl(NsExprMixin): + n: int | float | ExpressionBaseMixin + shape: tuple[int, int] | ExpressionBaseMixin + + @dataclassabc.dataclassabc(frozen=True, repr=False) class ParametrizeExprImpl(ParametrizeExprMixin): underlying: ExpressionBaseMixin @@ -375,7 +388,7 @@ class ShapeExprImpl(ShapeExprMixin): @dataclassabc.dataclassabc(frozen=True) class SliceExprImpl(SliceExprMixin): underlying: ExpressionBaseMixin - slice: tuple # See SlicePolyMatrixMixin for details of this tuple + slice: tuple # See SlicePolyMatrixMixin @dataclassabc.dataclassabc(frozen=True) @@ -437,14 +450,3 @@ class VariableImpl(VariableMixin): def __str__(self): return self.name - - -@dataclassabc.dataclassabc(frozen=True) -class VStackExprImpl(VStackExprMixin): - underlying: tuple[ExpressionBaseMixin, ...] - - def __str__(self): - inner = ", ".join(map(str, self.underlying)) - return f"v_stack({inner})" - - diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index c109156..b51b92c 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -3,7 +3,7 @@ import numpy as np import numpy.typing as npt import sympy -from polymatrix.expression.typing import FromDataTypes +from polymatrix.expression.typing import FromSupportedTypes import polymatrix.expression.impl @@ -48,7 +48,7 @@ def init_cache_expr( def init_combinations_expr( expression: ExpressionBaseMixin, - degrees: tuple[int, ...] | int, + degrees: tuple[int, ...] | int | ExpressionBaseMixin, ): if isinstance(degrees, int): degrees = (degrees,) @@ -157,10 +157,11 @@ def init_from_statemonad(monad: StateMonad): return polymatrix.expression.impl.FromStateMonadImpl(monad=monad) +# TODO: remove this function, replaced by from_any # NP: this function should be split up into smaller functions, one for each "from" type # NP: and each "from" should be documented, explaining how it is interpreted. def init_from_expr_or_none( - data: FromDataTypes, + data: FromSupportedTypes, ) -> ExpressionBaseMixin | None: if isinstance(data, VariableExpression): return data @@ -208,7 +209,7 @@ def init_from_expr_or_none( return None -def init_from_expr(data: FromDataTypes): +def init_from_expr(data: FromSupportedTypes): expr = init_from_expr_or_none(data) if expr is None: @@ -295,6 +296,10 @@ def init_max_expr( ) +def init_ns_expr(n: int | float | ExpressionBaseMixin, shape: tuple[int, int] | ExpressionBaseMixin): + return polymatrix.expression.impl.NsExprImpl(n, shape) + + def init_parametrize_expr( underlying: ExpressionBaseMixin, name: str = None, # FIXME: typing diff --git a/polymatrix/expression/mixins/nsexprmixin.py b/polymatrix/expression/mixins/nsexprmixin.py new file mode 100644 index 0000000..f97d724 --- /dev/null +++ b/polymatrix/expression/mixins/nsexprmixin.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from abc import abstractmethod +from math import sqrt, isclose +from typing_extensions import override + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.index import PolyDict, MonomialIndex +from polymatrix.polymatrix.init import init_broadcast_poly_matrix +from polymatrix.polymatrix.mixins import PolyMatrixMixin + +class NsExprMixin(ExpressionBaseMixin): + """ + Make a matrix or vector that is filled with N's. This is used to provide + functions line ones() or zeros(), i.e. N = 1 and N = 0 respectively. N may + not only be a number but also any scalar expression. + """ + + @property + @abstractmethod + def n(self) -> int | float | ExpressionBaseMixin: + """ Value that is used to fill the matrix """ + + @property + @abstractmethod + def shape(self) -> tuple[int, int] | ExpressionBaseMixin: + """ + Shape of the ones. If it is an expression it must evaluate to a 2d + column vector of integers. + """ + + @override + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: + if isinstance(self.n, ExpressionBaseMixin): + state, pm = self.n.apply(state) + n = pm.scalar() + + else: + n = PolyDict({ MonomialIndex.constant(): 1. }) + + if isinstance(self.shape, ExpressionBaseMixin): + state, pm = self.shape.apply(state) + if pm.shape != (2, 1): + raise ValueError("Shape must evaluate to a 2d column vector " + f"but it has shape {pm.shape}") + + # FIXME: should check that they are actually integers + nrows = int(pm.at(0, 0).constant()) + ncols = int(pm.at(1, 0).constant()) + + else: + nrows, ncols = self.shape + + return state, init_broadcast_poly_matrix(n, shape=(nrows, ncols)) + + + + diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py deleted file mode 100644 index 249bb26..0000000 --- a/polymatrix/expression/mixins/vstackexprmixin.py +++ /dev/null @@ -1,73 +0,0 @@ -import abc -import itertools -import dataclassabc -from polymatrix.polymatrix.mixins import PolyMatrixMixin -from polymatrix.polymatrix.index import PolyDict - -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate import ExpressionState - - -class VStackExprMixin(ExpressionBaseMixin): - """ - Vertical stacking of the underlying polynomial matrices - - [[1, 2]], [[3, 4]] -> [[1, 2], [3, 4]] - """ - - @property - @abc.abstractmethod - def underlying(self) -> tuple[ExpressionBaseMixin, ...]: ... - - # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: - all_underlying = [] - for expr in self.underlying: - state, polymatrix = expr.apply(state=state) - all_underlying.append(polymatrix) - - for underlying in all_underlying: - assert ( - underlying.shape[1] == all_underlying[0].shape[1] - ), f"{underlying.shape[1]} not equal {all_underlying[0].shape[1]}" - - # FIXME: move to polymatrix module - @dataclassabc.dataclassabc(frozen=True) - class VStackPolyMatrix(PolyMatrixMixin): - all_underlying: tuple[PolyMatrixMixin] - underlying_row_range: tuple[tuple[int, int], ...] - shape: tuple[int, int] - - def at(self, row: int, col: int) -> PolyDict: - for polymatrix, (row_start, row_end) in zip( - self.all_underlying, self.underlying_row_range - ): - if row_start <= row < row_end: - return polymatrix.at( - row=row - row_start, - col=col, - ) - - raise Exception(f"row {row} is out of bounds") - - underlying_row_range = tuple( - itertools.pairwise( - itertools.accumulate( - (expr.shape[0] for expr in all_underlying), initial=0 - ) - ) - ) - - n_row = sum(polymatrix.shape[0] for polymatrix in all_underlying) - - polymatrix = VStackPolyMatrix( - all_underlying=all_underlying, - shape=(n_row, all_underlying[0].shape[1]), - underlying_row_range=underlying_row_range, - ) - - return state, polymatrix diff --git a/polymatrix/expression/typing.py b/polymatrix/expression/typing.py index af762a0..87ade26 100644 --- a/polymatrix/expression/typing.py +++ b/polymatrix/expression/typing.py @@ -1,18 +1,17 @@ from __future__ import annotations +from numpy.typing import NDArray +import sympy + from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.statemonad import StateMonad -import numpy.typing as npt -import sympy - -FromDataTypes = ( +FromSupportedTypes = ( str - | npt.NDArray - | sympy.Matrix - | sympy.Expr - | tuple + | NDArray + | sympy.Matrix | sympy.Expr + | tuple[tuple[...]] | ExpressionBaseMixin | StateMonad ) diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py index b8af9f7..ac6d199 100644 --- a/polymatrix/polymatrix/init.py +++ b/polymatrix/polymatrix/init.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin, PolyMatrixMixin, SlicePolyMatrixMixin -# FIXME: use polymatrix.typing def init_poly_matrix( data: PolyMatrixDict, shape: tuple[int, int], diff --git a/polymatrix/statemonad.py b/polymatrix/statemonad.py index bfd63de..f3ddfd6 100644 --- a/polymatrix/statemonad.py +++ b/polymatrix/statemonad.py @@ -25,6 +25,7 @@ class StateMonad(Generic[State, U]): # was passed to the statemonads that are applied to expressions. For # example in to_sympy, you want so see what expression is converted to # sympy. + # arguments that were given to the function apply_func. # this field is optional arguments: U | None |