// SPDX-License-Identifier: Apache-2.0 // // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ------------------------------------------------------------------------ //! \addtogroup fn_svds //! @{ template inline bool svds_helper ( Mat& U, Col& S, Mat& V, const SpBase& X, const uword k, const typename T1::pod_type tol, const bool calc_UV, const typename arma_real_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); typedef typename T1::elem_type eT; typedef typename T1::pod_type T; arma_debug_check ( ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), "svds(): two or more output objects are the same object" ); arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" ); const unwrap_spmat tmp(X.get_ref()); const SpMat& A = tmp.M; const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k ); const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col(const_cast(A.values), A.n_nonzero, false)))) : T(0); if(A_max == T(0)) { // TODO: use reset instead ? S.zeros(kk); if(calc_UV) { U.eye(A.n_rows, kk); V.eye(A.n_cols, kk); } } else { SpMat C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) ); SpMat B = A / A_max; SpMat Bt = B.t(); C(0, A.n_rows, arma::size(B) ) = B; C(A.n_rows, 0, arma::size(Bt)) = Bt; Bt.reset(); B.reset(); Col eigval; Mat eigvec; eigs_opts opts; opts.tol = (tol / Datum::sqrt2); const bool status = eigs_sym(eigval, eigvec, C, kk, "la", opts); if(status == false) { U.soft_reset(); S.soft_reset(); V.soft_reset(); return false; } const T A_norm = max(eigval); const T tol2 = tol / Datum::sqrt2 * A_norm; uvec indices = find(eigval > tol2); if(indices.n_elem > kk) { indices = indices.subvec(0,kk-1); } else if(indices.n_elem < kk) { const uvec indices2 = find(abs(eigval) <= tol2); const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) ); if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); } } const uvec sorted_indices = sort_index(eigval, "descend"); S = eigval.elem(sorted_indices); S *= A_max; if(calc_UV) { uvec U_row_indices(A.n_rows, arma_nozeros_indicator()); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } uvec V_row_indices(A.n_cols, arma_nozeros_indicator()); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } U = Datum::sqrt2 * eigvec(U_row_indices, sorted_indices); V = Datum::sqrt2 * eigvec(V_row_indices, sorted_indices); } } if(S.n_elem < k) { arma_debug_warn_level(1, "svds(): found fewer singular values than specified"); } return true; } template inline bool svds_helper ( Mat& U, Col& S, Mat& V, const SpBase& X, const uword k, const typename T1::pod_type tol, const bool calc_UV, const typename arma_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); typedef typename T1::elem_type eT; typedef typename T1::pod_type T; if(arma_config::arpack == false) { arma_stop_logic_error("svds(): use of ARPACK must be enabled for decomposition of complex matrices"); return false; } arma_debug_check ( ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), "svds(): two or more output objects are the same object" ); arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" ); const unwrap_spmat tmp(X.get_ref()); const SpMat& A = tmp.M; const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k ); const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col(const_cast(A.values), A.n_nonzero, false)))) : T(0); if(A_max == T(0)) { // TODO: use reset instead ? S.zeros(kk); if(calc_UV) { U.eye(A.n_rows, kk); V.eye(A.n_cols, kk); } } else { SpMat C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) ); SpMat B = A / A_max; SpMat Bt = B.t(); C(0, A.n_rows, arma::size(B) ) = B; C(A.n_rows, 0, arma::size(Bt)) = Bt; Bt.reset(); B.reset(); Col eigval_tmp; Mat eigvec; eigs_opts opts; opts.tol = (tol / Datum::sqrt2); const bool status = eigs_gen(eigval_tmp, eigvec, C, kk, "lr", opts); if(status == false) { U.soft_reset(); S.soft_reset(); V.soft_reset(); return false; } const Col eigval = real(eigval_tmp); const T A_norm = max(eigval); const T tol2 = tol / Datum::sqrt2 * A_norm; uvec indices = find(eigval > tol2); if(indices.n_elem > kk) { indices = indices.subvec(0,kk-1); } else if(indices.n_elem < kk) { const uvec indices2 = find(abs(eigval) <= tol2); const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) ); if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); } } const uvec sorted_indices = sort_index(eigval, "descend"); S = eigval.elem(sorted_indices); S *= A_max; if(calc_UV) { uvec U_row_indices(A.n_rows, arma_nozeros_indicator()); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } uvec V_row_indices(A.n_cols, arma_nozeros_indicator()); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } U = Datum::sqrt2 * eigvec(U_row_indices, sorted_indices); V = Datum::sqrt2 * eigvec(V_row_indices, sorted_indices); } } if(S.n_elem < k) { arma_debug_warn_level(1, "svds(): found fewer singular values than specified"); } return true; } //! find the k largest singular values and corresponding singular vectors of sparse matrix X template inline bool svds ( Mat& U, Col& S, Mat& V, const SpBase& X, const uword k, const typename T1::pod_type tol = 0.0, const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, true); if(status == false) { arma_debug_warn_level(3, "svds(): decomposition failed"); } return status; } //! find the k largest singular values of sparse matrix X template inline bool svds ( Col& S, const SpBase& X, const uword k, const typename T1::pod_type tol = 0.0, const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); Mat U; Mat V; const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false); if(status == false) { arma_debug_warn_level(3, "svds(): decomposition failed"); } return status; } //! find the k largest singular values of sparse matrix X template arma_warn_unused inline Col svds ( const SpBase& X, const uword k, const typename T1::pod_type tol = 0.0, const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); Col S; Mat U; Mat V; const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false); if(status == false) { arma_stop_runtime_error("svds(): decomposition failed"); } return S; } //! @}