From 62b86209c6483001909d3d1c2d7f702a081f5208 Mon Sep 17 00:00:00 2001 From: Michael Schneeberger Date: Wed, 17 Aug 2022 17:23:05 +0200 Subject: improve error message in case a polynomial contains a unknown variable --- polymatrix/__init__.py | 51 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 2863164..78f27b8 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -7,6 +7,7 @@ import scipy.sparse import sympy from polymatrix.expression.expression import Expression +from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin from polymatrix.expressionstate.expressionstate import ExpressionState from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr from polymatrix.expression.init.initexpression import init_expression @@ -302,6 +303,9 @@ class MatrixBuffer: n_row: int n_param: int + def get_max_degree(self): + return max(degree for degree in self.data.keys()) + def add_buffer(self, index: int): if index <= 1: buffer = np.zeros((self.n_row, self.n_param**index), dtype=np.double) @@ -378,6 +382,9 @@ class MatrixRepresentations: if 2 <= max_idx: def func(x: np.ndarray) -> np.ndarray: + if isinstance(x, tuple) or isinstance(x, list): + x = np.array(x).reshape(-1, 1) + def acc_x_powers(acc, _): next = (acc @ x.T).reshape(-1, 1) return next @@ -471,6 +478,7 @@ def to_matrix_repr( # ordered_variable_index = tuple(sorted(set(gen_used_variables()))) # else: + state, ordered_variable_index = get_variable_indices(state, variables) variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)} @@ -495,10 +503,17 @@ def to_matrix_repr( continue for monomial, value in underlying_terms.items(): - try: - new_monomial = tuple(variable_index_map[var] for var, count in monomial for _ in range(count)) - except KeyError: - raise KeyError(f'{monomial=} is incompatible with {variable_index_map=}') + def gen_new_monomial(): + for var, count in monomial: + try: + new_variable = variable_index_map[var] + except KeyError: + raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}') + + for _ in range(count): + yield new_variable + + new_monomial = tuple(gen_new_monomial()) cols = monomial_to_index(n_param, new_monomial) @@ -507,7 +522,6 @@ def to_matrix_repr( for col in cols: buffer.add(row, col, sum(count for _, count in monomial), col_value) - # yield A, B, C yield buffer underlying_matrices = tuple(gen_underlying_matrices()) @@ -592,8 +606,31 @@ def to_sympy_expr( sympy_monomial = 1 for offset, count in monomial: - var = state.get_variable_from_offset(offset) - sympy_monomial *= var**count + + variable = state.get_key_from_offset(offset) + # def get_variable_from_offset(offset: int): + # for variable, (start, end) in state.offset_dict.items(): + # if start <= offset < end: + # assert end - start == 1, f'{start=}, {end=}, {variable=}' + + if isinstance(variable, sympy.core.symbol.Symbol): + variable_name = variable.name + elif isinstance(variable, ParametrizeExprMixin): + variable_name = variable.name + elif isinstance(variable, str): + variable_name = variable + else: + raise Exception(f'{variable=}') + + start, end = state.offset_dict[variable] + + if end - start == 1: + sympy_var = sympy.Symbol(variable_name) + else: + sympy_var = sympy.Symbol(f'{variable_name}_{offset - start + 1}') + + # var = get_variable_from_offset(offset) + sympy_monomial *= sympy_var**count sympy_polynomial += value * sympy_monomial -- cgit v1.2.1