From 0a0020770513c517ca81eb2edbbaba342078efd2 Mon Sep 17 00:00:00 2001 From: Michael Schneeberger Date: Mon, 17 Jan 2022 15:36:49 +0100 Subject: add skew-symmetry and gradient function --- test_polymatrix/test_polymatrix.py | 67 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 3 deletions(-) (limited to 'test_polymatrix/test_polymatrix.py') diff --git a/test_polymatrix/test_polymatrix.py b/test_polymatrix/test_polymatrix.py index 1f0bb73..3b239bf 100644 --- a/test_polymatrix/test_polymatrix.py +++ b/test_polymatrix/test_polymatrix.py @@ -135,12 +135,12 @@ class TestPolyMatrix(unittest.TestCase): value = 2, ) - def test_param_matrix_const_vector_skew_symmetric(self): + def test_skew_symmetric_param_matrix_const_vector(self): """ param = [a11 a21 a31 a41] """ - def skew_symmetric(idx1, idx2): + def skew_symmetric(degree, idx1, idx2): if idx1 == idx2: return idx1, idx2, 0 elif idx2 < idx1: @@ -173,4 +173,65 @@ class TestPolyMatrix(unittest.TestCase): eq_idx = 1, row_idx = offset_dict[(mat, 0)] + 1, value = -1, - ) \ No newline at end of file + ) + + def test_const_matrix_param_gradient_vector(self): + """ + param = [v11 v12 v21 v22] + """ + + def gradient(degree, v_row, monom): + if degree == 1: + factor = sum(v_row==e for e in monom) + 1 + + if monom[-1] < v_row: + n_v_row = monom[-1] + n_monom = sorted(monom + (v_row,), reverse=True) + + if v_row <= monom[-1]: + n_v_row = v_row + n_monom = monom + + return n_v_row, n_monom, factor + + mat = init_poly_matrix( + subs={0: {(0, 0): 1, (1, 0): 1, (0, 1): 1, (1, 1): 1}}, + ) + vec = init_poly_vector( + degrees=(1,), + re_index_func_2=gradient, + ) + + eq = init_equation( + terms = [(mat, vec)], + n_var = 2, + ) + + terms, offset_dict = list(eq.create()) + + self.assert_term_in_eq( + terms = terms, + degree = 1, + monoms=(0,), + eq_idx = 0, + row_idx = offset_dict[(vec, 1)], + value = 2, + ) + + self.assert_term_in_eq( + terms = terms, + degree = 1, + monoms=(1,), + eq_idx = 0, + row_idx = offset_dict[(vec, 1)]+1, + value = 1, + ) + + self.assert_term_in_eq( + terms = terms, + degree = 1, + monoms=(0,), + eq_idx = 1, + row_idx = offset_dict[(vec, 1)]+1, + value = 1, + ) -- cgit v1.2.1