From 4abcab8da3e553390cd58b42cae52f210d011ada Mon Sep 17 00:00:00 2001
From: Michael Schneeberger <michael.schneeberger@fhnw.ch>
Date: Mon, 21 Feb 2022 16:21:50 +0100
Subject: bugfixes

---
 polymatrix/polystruct.py | 492 +++++++++++++++++++++++++++++++++++------------
 polymatrix/sympyutils.py |  55 +++---
 2 files changed, 393 insertions(+), 154 deletions(-)

diff --git a/polymatrix/polystruct.py b/polymatrix/polystruct.py
index 64c0f42..202f6cc 100644
--- a/polymatrix/polystruct.py
+++ b/polymatrix/polystruct.py
@@ -1,11 +1,13 @@
 import abc
 import collections
+from pickletools import int4
 import typing
 import numpy as np
 import dataclasses
 import dataclass_abc
-import scipy.special
+from scipy.special import binom
 import itertools
+import functools
 
 from polymatrix.utils import variable_to_index
 
@@ -30,76 +32,257 @@ class PolyMatrixMixin(abc.ABC):
 
     @property
     @abc.abstractmethod
-    def subs_func(self) -> typing.Callable[[DegreeType, int, int], tuple[float]]:
+    def re_index(self) -> typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]]:
         ...
 
     @property
     @abc.abstractmethod
-    def re_index(self) -> dict[DegreeType, dict[int, int, tuple[int, int, float]]]:
+    def is_constant(self) -> int:
         ...
 
     @property
     @abc.abstractmethod
-    def re_index_func(self) -> typing.Callable[[int, int], tuple[int, int, float]]:
+    def is_vector(self) -> bool:
         ...
 
+
+class EqualityConstraintMixin(abc.ABC):
     @property
     @abc.abstractmethod
-    def re_index_func_2(self) -> typing.Callable[[int, int], tuple[int, int, float]]:
+    def terms(self) -> dict[int, dict[tuple[int, tuple, int], float]]:
         ...
 
+    # @property
+    # @abc.abstractmethod
+    # def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], int]:
+    #     ...
+
+    # @property
+    # @abc.abstractmethod
+    # def param_dict(self) -> dict[tuple[PolyMatrixMixin, int], int]:
+    #     ...
+
     @property
     @abc.abstractmethod
-    def is_vector(self) -> bool:
+    def n_param(self) -> int:
         ...
 
+    @functools.cached_property
+    def eq_to_row_index(self):
+        rows_to_eq = list(set((eq_idx, perm) for eq_tuple_degree in self.terms.values() for (eq_idx, perm, var) in eq_tuple_degree.keys()))
+        eq_to_rows = {eq: idx for idx, eq in enumerate(rows_to_eq)}
+        return eq_to_rows
 
-class EquationMixin(abc.ABC):
+    @property
+    def n_eq(self):
+        return len(self.eq_to_row_index)
+
+    # @property
+    # @abc.abstractmethod
+    # def n_var(self) -> int:
+    #     ...
+
+    def get_constraint_func(self):
+        def func(x):
+            mat = np.zeros((self.n_eq,))
+
+            for degree, degree_tuples in self.terms.items():
+                if 0 == degree:
+                    for (idx_eq, perm, variables), value in degree_tuples.items():
+                        row_idx = self.eq_to_row_index[(idx_eq, perm)]
+                        mat[row_idx] += value
+
+                elif 0 < degree:
+                    def gen_vector():
+                        for indices in itertools.combinations_with_replacement(range(self.n_param), degree):
+                            yield np.prod(list(x[idx] for idx in indices))   
+                    vector = list(gen_vector())
+
+                    for (idx_eq, perm, variables), value in degree_tuples.items():
+                        row_idx = self.eq_to_row_index[(idx_eq, perm)]
+                        vector_val = vector[variable_to_index(self.n_param, variables)]
+                        mat[row_idx] += value * vector_val
+
+            return mat
+        return func
+
+    def get_constraint_jacobian(self):
+        def func(x):
+            jac_mat = np.zeros((self.n_eq, self.n_param))
+
+            for degree, degree_tuples in self.terms.items():
+                if 1 == degree:
+                    for (idx_eq, perm, variables), value in degree_tuples.items():
+                        row_idx = self.eq_to_row_index[(idx_eq, perm)]
+                        jac_mat[row_idx, variables[0]] += value
+
+                        # for var in variables:
+                        #     col_idx = variable_to_index(self.n_param, (var,))
+                        #     jac_mat[row_idx, col_idx] += value
+
+                elif 1 < degree:
+                    def gen_vector():
+                        for indices in itertools.combinations_with_replacement(range(self.n_param), degree-1):
+                            yield np.prod(list(x[idx] for idx in indices))   
+                    vector = list(gen_vector())
+
+                    for (idx_eq, perm, variables), value in degree_tuples.items():
+                        row_idx = self.eq_to_row_index[(idx_eq, perm)]
+
+                        for var_idx, var in enumerate(variables):
+                            other_variables = variables[:var_idx] + variables[var_idx+1:]
+                            vector_val = vector[variable_to_index(self.n_param, other_variables)]
+
+                            # col_idx = variable_to_index(self.n_param, (var,))
+                            jac_mat[row_idx, var] += value*vector_val
+
+            return jac_mat
+        return func
+
+    def get_constraint_hessian(self):
+        def func(x, v):
+            hess_mat = np.zeros((self.n_param, self.n_param))
+
+            for degree, degree_tuples in self.terms.items():
+                if 2 == degree:
+                    for (idx_eq, perm, variables), value in degree_tuples.items():
+                        eq_idx = self.eq_to_row_index[(idx_eq, perm)]
+
+                        for var_idx_x, var_x in enumerate(variables):
+                            other_variables = variables[:var_idx_x] + variables[var_idx_x+1:]
+
+                            for var_idx_y, var_y in enumerate(other_variables):
+                                hess_mat[var_x, var_y] = v[eq_idx]*value
+
+                elif 2 < degree:
+                    def gen_vector():
+                        for indices in itertools.combinations_with_replacement(range(self.n_param), degree-2):
+                            yield np.prod(list(x[idx] for idx in indices))   
+                    vector = list(gen_vector())
+
+                    for (idx_eq, perm, variables), value in degree_tuples.items():
+                        eq_idx = self.eq_to_row_index[(idx_eq, perm)]
+
+                        for var_idx_x, var_x in enumerate(variables):
+                            other_variables = variables[:var_idx_x] + variables[var_idx_x+1:]
+
+                            for var_idx_y, var_y in enumerate(other_variables):
+                                other_variables_2 = variables[:var_idx_y] + variables[var_idx_y+1:]
+                                vector_val = vector[variable_to_index(self.n_param, other_variables_2)]
+                                hess_mat[var_x, var_y] = v[eq_idx]*value*vector_val
+
+            return hess_mat
+        return func
+
+
+class PolyEquationMixin(abc.ABC):
     @property
     @abc.abstractmethod
     def terms(self) -> list[tuple[PolyMatrixMixin, PolyMatrixMixin]]:
+        """
+        the terms the polynomial matrix equation consists of
+        """
+
         ...
 
     @property
     @abc.abstractmethod
     def n_var(self) -> int:
-        ...
+        """
+        number of variables defining the polynomials
 
-    def create(
-        self,
-        subs: dict[PolyMatrixMixin, dict[DegreeType, dict[int, int, float]]] = None,
-    ):
-        if subs is None:
-            added_subs = {}
-        else:
-            added_subs = subs
+        for example n_var=3: x1, x2 and x3
+        """
 
-        binom = scipy.special.binom
+        ...
+
+    @functools.cached_property
+    def _param_list(self) -> list[tuple[PolyMatrixMixin, int], int]:
+        """
+        used to determine the offset of the coefficients of each polynomial matrix
+        """
 
         # create parameter offset
         all_structs = set(indexed_poly_mat for term in self.terms for indexed_poly_mat in term)
 
         def gen_n_param_per_struct():
-            acc = 0
+            # acc = 0
 
             for struct in all_structs:
 
-                for degree in struct.degrees:
+                if struct.is_constant:
+                    continue
 
-                    yield (struct, degree), acc
+                for degree in struct.degrees:
 
                     if struct.is_vector:
                         number_of_terms = int(self.n_var * binom(self.n_var+degree-1, degree))
                     else:
                         number_of_terms = int(self.n_var**2 * binom(self.n_var+degree-1, degree))
 
-                    acc += number_of_terms
-
-            yield None, acc #, None
+                    yield (struct, degree), number_of_terms
 
         param_list = list(gen_n_param_per_struct())
-        n_param = param_list[-1][1]
-        offset_dict = dict(((e[0], e[1]) for e in param_list[:-1]))
+
+        return param_list
+
+    @functools.cached_property
+    def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], int]:
+        """
+        determine the offset of the coefficients of each polynomial matrix ordered by degree
+
+        The polynomial equation
+        
+            A * B = 0
+
+        is represented by a vector of coefficients `coeff`. Each coefficients is associated to polynomial matrix.
+
+        For example
+
+            offset_dict[(A,0)] = 12
+
+        means that the first coefficient a011 (meaning 011=degree+row+col) associated to A and degree 0 is located at index 12 of `coeff`.
+        """
+
+        param_key_value = list(zip(*self._param_list))
+
+        if 0 < len(param_key_value):
+            param_key, param_value = param_key_value
+            cum_sum = list(itertools.accumulate(param_value))
+            offset_dict = dict(zip(param_key, [0] + cum_sum[:-1]))
+        else:
+            offset_dict = {}
+
+        return offset_dict
+
+    @functools.cached_property
+    def n_param(self) -> dict[tuple[PolyMatrixMixin, int], int]:
+        """
+        number of coefficients of polynomial matrix equation, e.g. `len(coeff)`
+        """
+        
+        if 0 < len(self._param_list):
+            *_, n_param = itertools.accumulate(e[1] for e in self._param_list)
+        else:
+            n_param = 0
+
+        return n_param
+
+    def create(
+        self,
+        subs: dict[PolyMatrixMixin, dict[DegreeType, dict[int, int, float]]] = None,
+    ) -> EqualityConstraintMixin:
+        n_var_2 = self.n_var**2
+
+        if subs is None:
+            added_subs = {}
+        else:
+            added_subs = subs
+
+        # binom = scipy.special.binom
+
+        # create parameter offset
+        all_structs = set(indexed_poly_mat for term in self.terms for indexed_poly_mat in term)
 
         def gen_substitutions():
             for struct in all_structs:
@@ -123,47 +306,12 @@ class EquationMixin(abc.ABC):
 
                     else:
                         all_subs = None
-                
-                def subs_func(degree, idx_eq, idx_col, all_subs=all_subs, struct=struct):
 
-                    if struct.re_index_func is not None:
-                        re_index = struct.re_index_func(degree, idx_eq, idx_col)
-                    else:
-                        re_index = None
-
-                    if re_index is None:
-                        n_coord = (idx_eq, idx_col)
-                        factor = 1
-                    else:
-                        n_coord = (re_index[0], re_index[1])
-                        factor = re_index[2]
-
-                    if all_subs is not None and degree in all_subs:
-                        all_degree_subs = all_subs[degree]
-
-                        if n_coord in all_degree_subs:
-                            subs_val = all_degree_subs[n_coord] * factor
-                        else:
-                            subs_val = None
-
-                    elif struct.subs_func is not None:
-                        sub_value = struct.subs_func(degree, n_coord[0], n_coord[1])
-
-                        if sub_value is not None:
-                            subs_val = sub_value * factor
-                        else:
-                            subs_val = None
-
-                    else:
-                        subs_val = None
-
-                    return n_coord, factor, subs_val
-
-                yield struct, subs_func
+                yield struct, all_subs
 
         subs_dict = dict(gen_substitutions())
 
-        terms = collections.defaultdict(lambda: collections.defaultdict(lambda: collections.defaultdict(int)))
+        terms = collections.defaultdict(lambda: collections.defaultdict(float))
 
         for left, right in self.terms:
 
@@ -172,13 +320,21 @@ class EquationMixin(abc.ABC):
 
             for d1 in left.degrees:
 
-                n_param_1 = binom(self.n_var+d1-1, d1)
-                offset_1 = offset_dict[(left, d1)]
+                if subs_1 is not None and d1 in subs_1:
+                    subs_1_d = subs_1[d1]
+                else:
+                    subs_1_d = None
+
+                offset_1 = self.offset_dict.get((left, d1), 0)
 
                 for d2 in right.degrees:
 
-                    n_param_2 = binom(self.n_var+d2-1, d2)
-                    offset_2 = offset_dict[(right, d2)]
+                    if subs_2 is not None and d2 in subs_2:
+                        subs_2_d = subs_2[d2]
+                    else:
+                        subs_2_d = None
+
+                    offset_2 = self.offset_dict.get((right, d2), 0)
 
                     total_degree = d1 + d2
 
@@ -199,74 +355,150 @@ class EquationMixin(abc.ABC):
                             # (1,0) -> x2*x1 instead of (0,1)->x1*x2
                             if non_increasing(grp1) and non_increasing(grp2):
 
-                                left_idx = variable_to_index(self.n_var, grp1)
-                                d_right_idx = variable_to_index(self.n_var, grp2)
+                                left_col_default = variable_to_index(self.n_var, grp1)
+                                right_col_default = variable_to_index(self.n_var, grp2)
 
                                 # for each column of the poly matrix, and row of the poly vector
-                                for idx_col in range(self.n_var):
+                                for left_poly_col in range(self.n_var):
 
-                                    if right.re_index_func_2 is None:
-                                        re_index_2 = None
+                                    right_poly_row = left_poly_col
+
+                                    if right.re_index is not None:
+                                        re_index_2 = right.re_index(d2, right_poly_row, 0, grp2)
                                     else:
-                                        re_index_2 = right.re_index_func_2(d2, idx_col, grp2)
+                                        re_index_2 = None
 
                                     if re_index_2 is None:
-                                        v_idx_row = idx_col
-                                        factor_22 = 1
-                                        right_idx = d_right_idx
+                                        n_right_poly_row, factor_2, right_col = right_poly_row, 1, right_col_default
                                     else:
-                                        v_idx_row, n_grp2, factor_22 = re_index_2
-                                        right_idx = variable_to_index(self.n_var, n_grp2)
-
-                                    n_coord_2, factor_21, subs_val_2 = subs_2(d2, v_idx_row, 0)
-
-                                    factor_2 = factor_21 * factor_22
+                                        n_right_poly_row, _, n_grp2, factor_2 = re_index_2
+                                        right_col = variable_to_index(self.n_var, n_grp2)
+
+                                    key_2 = (n_right_poly_row, 0, right_col)
+                                    if subs_2_d is not None:
+                                        try:
+                                            subs_val_2 = subs_2_d[key_2]
+                                        except KeyError:
+                                            subs_val_2 = None
+                                    else:
+                                        subs_val_2 = None
 
                                     if factor_2 == 0 or subs_val_2 == 0:
                                         continue
 
-                                    # for each polynomial equation
-                                    for idx_eq in range(self.n_var):
+                                    right_row = n_right_poly_row
+                                    right_param_idx = int(offset_2 + right_row + right_col * self.n_var)
 
-                                        n_coord_1, factor_1, subs_val_1 = subs_1(d1, idx_eq, idx_col)
+                                    # for each polynomial equation
+                                    for left_poly_row in range(self.n_var):
+
+                                        if left.re_index is not None:
+                                            re_index_1 = left.re_index(d1, left_poly_row, left_poly_col, grp1)
+                                        else:
+                                            re_index_1 = None
+
+                                        if re_index_1 is None:
+                                            n_left_poly_row, n_left_poly_col, factor_1 = left_poly_row, left_poly_col, 1
+                                            left_col = left_col_default
+                                        else:
+                                            n_left_poly_row, n_left_poly_col, n_grp1, factor_1 = re_index_1
+                                            left_col = variable_to_index(self.n_var, n_grp1)
+
+                                        key_1 = (n_left_poly_row, n_left_poly_col, left_col)
+                                        if subs_1_d is not None:
+                                            try:
+                                                subs_val_1 = subs_1_d[key_1]
+                                            except KeyError:
+                                                subs_val_1 = None
+                                        else:
+                                            subs_val_1 = None
 
                                         if factor_1 == 0 or subs_val_1 == 0:
                                             continue
-                                        
-                                        left_param_idx = int(offset_1 + left_idx + (self.n_var * n_coord_1[0] + n_coord_1[1]) * n_param_1)
-                                        right_param_idx = int(offset_2 + right_idx + n_coord_2[0] * n_param_2)
+
+                                        left_row = n_left_poly_row + n_left_poly_col * self.n_var
+                                        left_param_idx = int(offset_1 + left_row + left_col * n_var_2)
+
+                                        # print(left_poly_row)
+                                        # print(f'{left_row=}')
+                                        # print(f'{left_param_idx=}')
+
+                                        total_factor = factor_1 * factor_2
 
                                         match (subs_val_1, subs_val_2):
                                             case (None, None):
-                                                col_idx = variable_to_index(n_param, (left_param_idx, right_param_idx))
-                                                # col_idx = (left_param_idx, right_param_idx)
+                                                col_idx = (left_param_idx, right_param_idx)
                                                 degree = 2
-                                                value = factor_1*factor_2
+                                                value = total_factor
 
-                                                terms[1][(idx_eq, perm)][left_param_idx] += 2*value
-                                                terms[1][(idx_eq, perm)][right_param_idx] += 2*value
-                                                terms[0][(idx_eq, perm)][0] += value
+                                                # terms[1][(idx_eq, perm)][left_param_idx] += 2*value
+                                                # terms[1][(idx_eq, perm)][right_param_idx] += 2*value
+                                                # terms[0][(idx_eq, perm)][0] += value
                                             
                                             case (subs_val, None):
-                                                col_idx = right_param_idx
-                                                # col_idx = (right_param_idx,)
+                                                col_idx = (right_param_idx,)
                                                 degree = 1
-                                                value = subs_val_1*factor_2
+                                                value = subs_val_1*total_factor
 
-                                                terms[degree][(idx_eq, perm)][0] += value
+                                                # terms[degree][(idx_eq, perm)][0] += value
 
                                             case (None, subs_val):
-                                                col_idx = left_param_idx
-                                                # col_idx = (left_param_idx,)
+                                                col_idx = (left_param_idx,)
                                                 degree = 1
-                                                value = subs_val_2*factor_1
+                                                value = subs_val_2*total_factor
+
+                                                # terms[degree][(idx_eq, perm)][0] += value
                                             
                                             case _:
-                                                degree, col_idx, value = 0, tuple(), subs_val_1*subs_val_2*factor_1*factor_2
+                                                degree, col_idx, value = 0, tuple(), subs_val_1*subs_val_2*total_factor
+
+                                        terms[degree][left_poly_row, perm, col_idx] += value
+
+        return EqualityConstraintImpl(
+            terms=terms, 
+            n_param=self.n_param,
+        )
 
-                                        terms[degree][(idx_eq, perm)][col_idx] += value
+    def matrix_to_poly(self, struct, x, param, tol=None):
+        assert len(x) == self.n_var, f'variable {x} needs to be of length {self.n_var}'
 
-        return terms, offset_dict, n_param
+        n_var_2 = self.n_var**2
+        
+        if struct.is_vector:
+            n_col = 1
+            
+        else:
+            n_col = self.n_var
+            
+        sym_expr = [[0 for _ in range(n_col)] for _ in range(self.n_var)]
+            
+        for degree in struct.degrees:
+            offset = self.offset_dict[(struct, degree)] 
+            # number_of_terms = int(binom(self.n_var+degree-1, degree))
+            
+            def write_to_expr(row, col, val, term=1):
+                if tol is None or val <= -tol or tol <= val:
+                    sym_expr[row][col] += val*term
+
+            if 0 == degree:
+                for row in range(self.n_var):
+                    for col in range(n_col):
+                        write_to_expr(row, col, param[offset + row + col * self.n_var])
+                        
+            else:
+                def gen_vector():
+                    for comb in itertools.combinations_with_replacement(range(self.n_var), degree):
+                        *_, last = itertools.accumulate(comb, lambda acc, idx: acc*x[idx], initial=1)
+                        yield last
+                vector = list(gen_vector())
+                
+                for row in range(self.n_var):
+                    for col in range(n_col):
+                        for idx, term in enumerate(vector):
+                            # print(f'{offset + (row + col * self.n_var) * number_of_terms + idx=}, {param[offset + (row + col * self.n_var) * number_of_terms + idx]=}')
+                            write_to_expr(row, col, param[offset + row + col * self.n_var + idx * n_var_2], term)
+                            
+        return sym_expr
 
 ########################################
 # Classes
@@ -276,9 +508,14 @@ class PolyMatrix(PolyMatrixMixin):
     pass
 
 
-class Equation(EquationMixin):
+class EqualityConstraint(EqualityConstraintMixin):
+    pass
+
+
+class PolyEquation(PolyEquationMixin):
     pass
 
+
 ########################################
 # Implementations
 ########################################
@@ -287,15 +524,22 @@ class Equation(EquationMixin):
 class PolyMatrixImpl(PolyMatrix):
     degrees: list[int]
     subs: dict[int, dict[tuple[int, int], float]]
-    subs_func: typing.Callable[[int, int, int], tuple[float]]
-    re_index: dict[int, dict[tuple[int, int], tuple[int, int, float]]]
-    re_index_func: typing.Callable[[int, int], tuple[int, int, float]]
-    re_index_func_2: typing.Callable[[int, int, tuple[int, ...]], tuple[int, tuple[int, ...], float]]
+    re_index: typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]]
+    is_constant: bool
     is_vector: bool
 
 
 @dataclass_abc.dataclass_abc(frozen=True)
-class EquationImpl(Equation):
+class EqualityConstraintImpl(EqualityConstraintMixin):
+    terms: dict[int, dict[tuple[int, tuple, int], float]]
+    # offset_dict: dict[tuple[PolyMatrixMixin, int], int]
+    # param_dict: dict[tuple[PolyMatrixMixin, int], int]
+    n_param: int
+    # n_var: int
+
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class EquationImpl(PolyEquation):
     terms: list[tuple[PolyMatrix, PolyMatrix]]
     n_var: int
 
@@ -306,40 +550,34 @@ class EquationImpl(Equation):
 def init_poly_vector(
     degrees: list[int] = None,
     subs: dict[int, dict[tuple[int, int], float]] = None,
-    subs_func: typing.Callable[[int, int, int], tuple[float]] = None,
-    re_index: dict[int, dict[tuple[int, int], tuple[int, int, float]]] = None,
-    re_index_func: typing.Callable[[int, int], tuple[int, int, float]] = None,
-    re_index_func_2: typing.Callable[[int, int], tuple[int, int, float]] = None,
+    re_index: typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]] = None,
+    is_constant: bool = None,
 ):
     return init_poly_matrix(
         degrees=degrees,
         subs=subs,
-        subs_func=subs_func,
         re_index=re_index,
-        re_index_func=re_index_func,
         is_vector=True,
-        re_index_func_2=re_index_func_2,
+        is_constant=is_constant,
     )
 
 
 def init_poly_matrix(
     degrees: list[int] = None,
     subs: dict[int, dict[tuple[int, int], float]] = None,
-    subs_func: typing.Callable[[int, int, int], tuple[float]] = None,
-    re_index: dict[int, dict[tuple[int, int], tuple[int, int, float]]] = None,
-    re_index_func: typing.Callable[[int, int], tuple[int, int, float]] = None,
-    re_index_func_2: typing.Callable[[int, int], tuple[int, int, float]] = None,
+    re_index: typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]] = None,
     is_vector: bool = None,
+    is_constant: bool = None,
 ):
     if degrees is None:
         assert isinstance(subs, dict)
         degrees = list(subs.keys())
 
-    if subs is None and subs_func is None:
-        subs = {}
-
-    if re_index is None and re_index_func is None:
-            re_index = {}
+    if is_constant is None:
+        if subs is None:
+            is_constant = False
+        else:
+            is_constant = True
 
     if is_vector is None:
         is_vector = False
@@ -347,10 +585,8 @@ def init_poly_matrix(
     return PolyMatrixImpl(
         degrees=degrees,
         subs=subs,
-        subs_func=subs_func,
         re_index = re_index,
-        re_index_func = re_index_func,
-        re_index_func_2 = re_index_func_2,
+        is_constant = is_constant,
         is_vector = is_vector,
     )
 
diff --git a/polymatrix/sympyutils.py b/polymatrix/sympyutils.py
index e2c37d8..6fe7938 100644
--- a/polymatrix/sympyutils.py
+++ b/polymatrix/sympyutils.py
@@ -1,11 +1,13 @@
+import collections
 import itertools
 import numpy as np
 import scipy.sparse
+import sympy
 
 from polymatrix.utils import variable_powers_to_index
 
 
-def poly_to_data_coord(poly_list, power = None):
+def poly_to_data_coord(poly_list, x, degree = None):
     """
     poly_list = [
         poly(x1*x3**2, x)
@@ -13,70 +15,71 @@ def poly_to_data_coord(poly_list, power = None):
     power: up to which power 
     """
 
-    if power is None:
-        power = max(degree for poly in poly_list for degree in poly.degree_list())
+    sympy_poly_list = tuple(tuple(sympy.poly(p, x) for p in inner_poly_list) for inner_poly_list in poly_list)
 
-    def all_equal(iterable):
-        g = itertools.groupby(iterable)
-        return next(g, True) and not next(g, False)
+    if degree is None:
+        degree = max(degree for inner_poly_list in sympy_poly_list for poly in inner_poly_list for degree in poly.degree_list())
 
-    assert all_equal((p.gens for p in poly_list)), 'all polynomials need to have identical generators'
+    # def all_equal(iterable):
+    #     g = itertools.groupby(iterable)
+    #     return next(g, True) and not next(g, False)
+
+    # assert all_equal((p.gens for p in poly_list)), 'all polynomials need to have identical generators'
 
     def gen_power_mat():
 
         # for all powers generate a matrix
-        for current_power in range(power + 1):
+        for current_degree in range(degree + 1):
         
             def gen_value_index():
 
                 # each polynomial defines a row in the matrix
-                for row, p in enumerate(poly_list):
+                for poly_row, inner_poly_list in enumerate(sympy_poly_list):
+
+                    for poly_col, p in enumerate(inner_poly_list):
 
-                    # a5 x1 x3**2 -> c=a5, m=(1, 0, 2)
-                    for c, m in zip(p.coeffs(), p.monoms()):
+                        # a5 x1 x3**2 -> c=a5, m=(1, 0, 2)
+                        for c, m in zip(p.coeffs(), p.monoms()):
 
-                        if sum(m) == current_power:
+                            if sum(m) == current_degree:
 
-                            index = variable_powers_to_index(m)
-                            yield (row, index), c
+                                index = variable_powers_to_index(m)
+                                yield (poly_row, poly_col, index), c
                     
-            # yield list(zip(*gen_value_index()))
-            data = dict(gen_value_index())
+            data = dict(gen_value_index()) | collections.defaultdict(float)
 
             if len(data) > 0:
-                yield current_power, data
+                yield current_degree, data
 
     return dict(gen_power_mat())
 
 
-def poly_to_matrix(poly_list, power = None):
+def poly_to_matrix(poly_list, x, power = None):
     """
     
     """
 
-    data_coord_dict = poly_to_data_coord(poly_list, power)
+    data_coord_dict = poly_to_data_coord(poly_list, x, power)
 
-    n_free_symbols = len(poly_list[0].gens)
+    # n_free_symbols = len(poly_list[0][0].gens)
+    n_free_symbols = len(x)
 
     def gen_power_mat():
 
         # for all powers generate a matrix
-        for current_power, data_coord in data_coord_dict:
+        for current_degree, data_coord in data_coord_dict:
             
             # empty matrix
-            shape = (len(poly_list), n_free_symbols**current_power)
+            shape = (len(poly_list), n_free_symbols**current_degree)
             # m = np.zeros((len(poly_list), n_free_symbols**current_power))
 
             # fill matrix
             if len(data_coord) == 0:
-                yield np.zeros((len(poly_list), n_free_symbols**current_power))
+                yield np.zeros((len(poly_list), n_free_symbols**current_degree))
 
             else:
                 rows, cols, data = list(zip(*data_coord))
 
                 yield scipy.sparse.coo_matrix((data, (rows, cols)), dtype=np.double, shape=shape)
-
-                # m[rows, cols] = data
-                # yield m
             
     return list(gen_power_mat())
-- 
cgit v1.2.1