From eda5bc26f44ee9a6f83dcf8c91f17296d7fc509d Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 12 Feb 2024 14:52:43 +0100 Subject: Move into version control --- src/armadillo/include/armadillo_bits/fn_svds.hpp | 352 +++++++++++++++++++++++ 1 file changed, 352 insertions(+) create mode 100644 src/armadillo/include/armadillo_bits/fn_svds.hpp (limited to 'src/armadillo/include/armadillo_bits/fn_svds.hpp') diff --git a/src/armadillo/include/armadillo_bits/fn_svds.hpp b/src/armadillo/include/armadillo_bits/fn_svds.hpp new file mode 100644 index 0000000..26c8c50 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_svds.hpp @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_svds +//! @{ + + +template +inline +bool +svds_helper + ( + Mat& U, + Col& S, + Mat& V, + const SpBase& X, + const uword k, + const typename T1::pod_type tol, + const bool calc_UV, + const typename arma_real_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + arma_debug_check + ( + ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), + "svds(): two or more output objects are the same object" + ); + + arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" ); + + const unwrap_spmat tmp(X.get_ref()); + const SpMat& A = tmp.M; + + const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k ); + + const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col(const_cast(A.values), A.n_nonzero, false)))) : T(0); + + if(A_max == T(0)) + { + // TODO: use reset instead ? + S.zeros(kk); + + if(calc_UV) + { + U.eye(A.n_rows, kk); + V.eye(A.n_cols, kk); + } + } + else + { + SpMat C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) ); + + SpMat B = A / A_max; + SpMat Bt = B.t(); + + C(0, A.n_rows, arma::size(B) ) = B; + C(A.n_rows, 0, arma::size(Bt)) = Bt; + + Bt.reset(); + B.reset(); + + Col eigval; + Mat eigvec; + + eigs_opts opts; + opts.tol = (tol / Datum::sqrt2); + + const bool status = eigs_sym(eigval, eigvec, C, kk, "la", opts); + + if(status == false) + { + U.soft_reset(); + S.soft_reset(); + V.soft_reset(); + + return false; + } + + const T A_norm = max(eigval); + + const T tol2 = tol / Datum::sqrt2 * A_norm; + + uvec indices = find(eigval > tol2); + + if(indices.n_elem > kk) + { + indices = indices.subvec(0,kk-1); + } + else + if(indices.n_elem < kk) + { + const uvec indices2 = find(abs(eigval) <= tol2); + + const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) ); + + if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); } + } + + const uvec sorted_indices = sort_index(eigval, "descend"); + + S = eigval.elem(sorted_indices); S *= A_max; + + if(calc_UV) + { + uvec U_row_indices(A.n_rows, arma_nozeros_indicator()); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } + uvec V_row_indices(A.n_cols, arma_nozeros_indicator()); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } + + U = Datum::sqrt2 * eigvec(U_row_indices, sorted_indices); + V = Datum::sqrt2 * eigvec(V_row_indices, sorted_indices); + } + } + + if(S.n_elem < k) { arma_debug_warn_level(1, "svds(): found fewer singular values than specified"); } + + return true; + } + + + +template +inline +bool +svds_helper + ( + Mat& U, + Col& S, + Mat& V, + const SpBase& X, + const uword k, + const typename T1::pod_type tol, + const bool calc_UV, + const typename arma_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(arma_config::arpack == false) + { + arma_stop_logic_error("svds(): use of ARPACK must be enabled for decomposition of complex matrices"); + return false; + } + + arma_debug_check + ( + ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), + "svds(): two or more output objects are the same object" + ); + + arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" ); + + const unwrap_spmat tmp(X.get_ref()); + const SpMat& A = tmp.M; + + const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k ); + + const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col(const_cast(A.values), A.n_nonzero, false)))) : T(0); + + if(A_max == T(0)) + { + // TODO: use reset instead ? + S.zeros(kk); + + if(calc_UV) + { + U.eye(A.n_rows, kk); + V.eye(A.n_cols, kk); + } + } + else + { + SpMat C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) ); + + SpMat B = A / A_max; + SpMat Bt = B.t(); + + C(0, A.n_rows, arma::size(B) ) = B; + C(A.n_rows, 0, arma::size(Bt)) = Bt; + + Bt.reset(); + B.reset(); + + Col eigval_tmp; + Mat eigvec; + + eigs_opts opts; + opts.tol = (tol / Datum::sqrt2); + + const bool status = eigs_gen(eigval_tmp, eigvec, C, kk, "lr", opts); + + if(status == false) + { + U.soft_reset(); + S.soft_reset(); + V.soft_reset(); + + return false; + } + + const Col eigval = real(eigval_tmp); + + const T A_norm = max(eigval); + + const T tol2 = tol / Datum::sqrt2 * A_norm; + + uvec indices = find(eigval > tol2); + + if(indices.n_elem > kk) + { + indices = indices.subvec(0,kk-1); + } + else + if(indices.n_elem < kk) + { + const uvec indices2 = find(abs(eigval) <= tol2); + + const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) ); + + if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); } + } + + const uvec sorted_indices = sort_index(eigval, "descend"); + + S = eigval.elem(sorted_indices); S *= A_max; + + if(calc_UV) + { + uvec U_row_indices(A.n_rows, arma_nozeros_indicator()); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } + uvec V_row_indices(A.n_cols, arma_nozeros_indicator()); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } + + U = Datum::sqrt2 * eigvec(U_row_indices, sorted_indices); + V = Datum::sqrt2 * eigvec(V_row_indices, sorted_indices); + } + } + + if(S.n_elem < k) { arma_debug_warn_level(1, "svds(): found fewer singular values than specified"); } + + return true; + } + + + +//! find the k largest singular values and corresponding singular vectors of sparse matrix X +template +inline +bool +svds + ( + Mat& U, + Col& S, + Mat& V, + const SpBase& X, + const uword k, + const typename T1::pod_type tol = 0.0, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, true); + + if(status == false) { arma_debug_warn_level(3, "svds(): decomposition failed"); } + + return status; + } + + + +//! find the k largest singular values of sparse matrix X +template +inline +bool +svds + ( + Col& S, + const SpBase& X, + const uword k, + const typename T1::pod_type tol = 0.0, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + Mat U; + Mat V; + + const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false); + + if(status == false) { arma_debug_warn_level(3, "svds(): decomposition failed"); } + + return status; + } + + + +//! find the k largest singular values of sparse matrix X +template +arma_warn_unused +inline +Col +svds + ( + const SpBase& X, + const uword k, + const typename T1::pod_type tol = 0.0, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + Col S; + + Mat U; + Mat V; + + const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false); + + if(status == false) { arma_stop_runtime_error("svds(): decomposition failed"); } + + return S; + } + + + +//! @} -- cgit v1.2.1