summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-08-17 17:23:05 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-08-17 17:23:05 +0200
commit62b86209c6483001909d3d1c2d7f702a081f5208 (patch)
tree49d484751ee586d6c8c44f598c32289f4d82f401
parentUpdate README (diff)
downloadpolymatrix-62b86209c6483001909d3d1c2d7f702a081f5208.tar.gz
polymatrix-62b86209c6483001909d3d1c2d7f702a081f5208.zip
improve error message in case a polynomial contains a unknown variable
Diffstat (limited to '')
-rw-r--r--polymatrix/__init__.py51
1 files changed, 44 insertions, 7 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 2863164..78f27b8 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -7,6 +7,7 @@ import scipy.sparse
import sympy
from polymatrix.expression.expression import Expression
+from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr
from polymatrix.expression.init.initexpression import init_expression
@@ -302,6 +303,9 @@ class MatrixBuffer:
n_row: int
n_param: int
+ def get_max_degree(self):
+ return max(degree for degree in self.data.keys())
+
def add_buffer(self, index: int):
if index <= 1:
buffer = np.zeros((self.n_row, self.n_param**index), dtype=np.double)
@@ -378,6 +382,9 @@ class MatrixRepresentations:
if 2 <= max_idx:
def func(x: np.ndarray) -> np.ndarray:
+ if isinstance(x, tuple) or isinstance(x, list):
+ x = np.array(x).reshape(-1, 1)
+
def acc_x_powers(acc, _):
next = (acc @ x.T).reshape(-1, 1)
return next
@@ -471,6 +478,7 @@ def to_matrix_repr(
# ordered_variable_index = tuple(sorted(set(gen_used_variables())))
# else:
+
state, ordered_variable_index = get_variable_indices(state, variables)
variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
@@ -495,10 +503,17 @@ def to_matrix_repr(
continue
for monomial, value in underlying_terms.items():
- try:
- new_monomial = tuple(variable_index_map[var] for var, count in monomial for _ in range(count))
- except KeyError:
- raise KeyError(f'{monomial=} is incompatible with {variable_index_map=}')
+ def gen_new_monomial():
+ for var, count in monomial:
+ try:
+ new_variable = variable_index_map[var]
+ except KeyError:
+ raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}')
+
+ for _ in range(count):
+ yield new_variable
+
+ new_monomial = tuple(gen_new_monomial())
cols = monomial_to_index(n_param, new_monomial)
@@ -507,7 +522,6 @@ def to_matrix_repr(
for col in cols:
buffer.add(row, col, sum(count for _, count in monomial), col_value)
- # yield A, B, C
yield buffer
underlying_matrices = tuple(gen_underlying_matrices())
@@ -592,8 +606,31 @@ def to_sympy_expr(
sympy_monomial = 1
for offset, count in monomial:
- var = state.get_variable_from_offset(offset)
- sympy_monomial *= var**count
+
+ variable = state.get_key_from_offset(offset)
+ # def get_variable_from_offset(offset: int):
+ # for variable, (start, end) in state.offset_dict.items():
+ # if start <= offset < end:
+ # assert end - start == 1, f'{start=}, {end=}, {variable=}'
+
+ if isinstance(variable, sympy.core.symbol.Symbol):
+ variable_name = variable.name
+ elif isinstance(variable, ParametrizeExprMixin):
+ variable_name = variable.name
+ elif isinstance(variable, str):
+ variable_name = variable
+ else:
+ raise Exception(f'{variable=}')
+
+ start, end = state.offset_dict[variable]
+
+ if end - start == 1:
+ sympy_var = sympy.Symbol(variable_name)
+ else:
+ sympy_var = sympy.Symbol(f'{variable_name}_{offset - start + 1}')
+
+ # var = get_variable_from_offset(offset)
+ sympy_monomial *= sympy_var**count
sympy_polynomial += value * sympy_monomial