diff options
-rw-r--r-- | polymatrix/__init__.py | 90 |
1 files changed, 53 insertions, 37 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 23fe913..64f3d63 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -7,10 +7,11 @@ import numpy as np import scipy.sparse import sympy -from polymatrix.expression.expression import Expression +from polymatrix.expression.mixins.expressionmixin import ExpressionMixin as Expression from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expression.mixins.expressionmixin import ExpressionMixin +# from polymatrix.expression.mixins.expressionmixin import ExpressionMixin from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin +from polymatrix.expression.mixins.parametrizematrixexprmixin import ParametrizeMatrixExprMixin from polymatrix.expressionstate.expressionstate import ExpressionState from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr from polymatrix.expression.init.initexpression import init_expression @@ -19,7 +20,7 @@ from polymatrix.expression.init.initfromsympyexpr import init_from_sympy_expr from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr from polymatrix.expression.init.initvstackexpr import init_v_stack_expr from polymatrix.polymatrix.polymatrix import PolyMatrix -from polymatrix.expression.utils.getvariableindices import get_variable_indices, get_variable_indices_from_variable +from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable from polymatrix.statemonad.init.initstatemonad import init_state_monad from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin from polymatrix.expression.utils.monomialtoindex import monomial_to_index @@ -52,8 +53,11 @@ def from_polymatrix( ) def from_( - data: tuple[tuple[float]], + data: str | tuple[tuple[float]], ): + if isinstance(data, str): + return from_(1).parametrize(data) + return from_sympy(data) def v_stack( @@ -308,7 +312,7 @@ def shape( @dataclasses.dataclass class MatrixBuffer: - data: dict[int, ...] + data: dict[int, np.ndarray] n_row: int n_param: int @@ -394,6 +398,9 @@ class MatrixRepresentations: if isinstance(x, tuple) or isinstance(x, list): x = np.array(x).reshape(-1, 1) + elif x.shape[0] == 1: + x = x.reshape(-1, 1) + def acc_x_powers(acc, _): next = (acc @ x.T).reshape(-1, 1) return next @@ -433,14 +440,14 @@ class MatrixRepresentations: def to_matrix_repr( expressions: Expression | tuple[Expression], - variables: Expression, + variables: Expression = None, sorted: bool = None, ) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]: if isinstance(expressions, Expression): expressions = (expressions,) - assert isinstance(variables, ExpressionBaseMixin), f'{variables=}' + assert isinstance(variables, ExpressionBaseMixin) or variables is None, f'{variables=}' def func(state: ExpressionState): @@ -453,32 +460,36 @@ def to_matrix_repr( return state, underlying_list + (underlying,) - *_, (state, underlying_list) = tuple(itertools.accumulate( + *_, (state, polymatrix_list) = tuple(itertools.accumulate( expressions, acc_underlying_application, initial=(state, tuple()), )) - state, variable_index = get_variable_indices_from_variable(state, variables) - - if sorted: - tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index) - - sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1])) + if variables is None: + sorted_variable_index = tuple() else: - sorted_variable_index = variable_index + state, variable_index = get_variable_indices_from_variable(state, variables) - assert len(sorted_variable_index) == len(set(sorted_variable_index)), f'{sorted_variable_index=} contains repeated variables' + if sorted: + tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index) + + sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1])) + + else: + sorted_variable_index = variable_index + + assert len(sorted_variable_index) == len(set(sorted_variable_index)), f'{sorted_variable_index=} contains repeated variables. Make sure you give a unique name for each variables.' variable_index_map = {old: new for new, old in enumerate(sorted_variable_index)} n_param = len(sorted_variable_index) - def gen_underlying_matrices(): - for underlying in underlying_list: + def gen_numpy_matrices(): + for polymatrix in polymatrix_list: - n_row = underlying.shape[0] + n_row = polymatrix.shape[0] buffer = MatrixBuffer( data={}, @@ -487,34 +498,39 @@ def to_matrix_repr( ) for row in range(n_row): - underlying_terms = underlying.get_poly(row, 0) - if underlying_terms is None: + polymatrix_terms = polymatrix.get_poly(row, 0) + + if polymatrix_terms is None: continue + + if len(polymatrix_terms) == 0: + buffer.add(row, 0, 0, 0) - for monomial, value in underlying_terms.items(): + else: + for monomial, value in polymatrix_terms.items(): - def gen_new_monomial(): - for var, count in monomial: - try: - index = variable_index_map[var] - except KeyError: - raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}') + def gen_new_monomial(): + for var, count in monomial: + try: + index = variable_index_map[var] + except KeyError: + raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}') - for _ in range(count): - yield index + for _ in range(count): + yield index - new_monomial = tuple(gen_new_monomial()) + new_monomial = tuple(gen_new_monomial()) - cols = monomial_to_index(n_param, new_monomial) + cols = monomial_to_index(n_param, new_monomial) - col_value = value / len(cols) + col_value = value / len(cols) - for col in cols: - buffer.add(row, col, sum(count for _, count in monomial), col_value) + for col in cols: + buffer.add(row, col, sum(count for _, count in monomial), col_value) yield buffer - underlying_matrices = tuple(gen_underlying_matrices()) + underlying_matrices = tuple(gen_numpy_matrices()) def gen_auxillary_equations(): for key, monomial_terms in state.auxillary_equations.items(): @@ -641,7 +657,7 @@ def to_sympy_repr( if isinstance(variable, sympy.core.symbol.Symbol): variable_name = variable.name - elif isinstance(variable, ParametrizeExprMixin): + elif isinstance(variable, (ParametrizeExprMixin, ParametrizeMatrixExprMixin)): variable_name = variable.name elif isinstance(variable, str): variable_name = variable |