From eda5bc26f44ee9a6f83dcf8c91f17296d7fc509d Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 12 Feb 2024 14:52:43 +0100 Subject: Move into version control --- .../include/armadillo_bits/op_sqrtmat_meat.hpp | 549 +++++++++++++++++++++ 1 file changed, 549 insertions(+) create mode 100644 src/armadillo/include/armadillo_bits/op_sqrtmat_meat.hpp (limited to 'src/armadillo/include/armadillo_bits/op_sqrtmat_meat.hpp') diff --git a/src/armadillo/include/armadillo_bits/op_sqrtmat_meat.hpp b/src/armadillo/include/armadillo_bits/op_sqrtmat_meat.hpp new file mode 100644 index 0000000..3c2fae5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sqrtmat_meat.hpp @@ -0,0 +1,549 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sqrtmat +//! @{ + + +//! implementation partly based on: +//! N. J. Higham. +//! A New sqrtm for Matlab. +//! Numerical Analysis Report No. 336, January 1999. +//! Department of Mathematics, University of Manchester. +//! ISSN 1360-1725 +//! http://www.maths.manchester.ac.uk/~higham/narep/narep336.ps.gz + + +template +inline +void +op_sqrtmat::apply(Mat< std::complex >& out, const mtOp,T1,op_sqrtmat>& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_sqrtmat::apply_direct(out, in.m); + + if(status == false) + { + arma_debug_warn_level(3, "sqrtmat(): given matrix is singular; may not have a square root"); + } + } + + + +template +inline +bool +op_sqrtmat::apply_direct(Mat< std::complex >& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type T; + + const diagmat_proxy P(expr.m); + + arma_debug_check( (P.n_rows != P.n_cols), "sqrtmat(): given matrix must be square sized" ); + + const uword N = P.n_rows; + + out.zeros(N,N); + + bool singular = false; + + for(uword i=0; i= T(0)) + { + singular = (singular || (val == T(0))); + + out.at(i,i) = std::sqrt(val); + } + else + { + out.at(i,i) = std::sqrt( std::complex(val) ); + } + } + + return (singular) ? false : true; + } + + + +template +inline +bool +op_sqrtmat::apply_direct(Mat< std::complex >& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_T; + typedef typename std::complex out_T; + + const quasi_unwrap expr_unwrap(expr.get_ref()); + const Mat& A = expr_unwrap.M; + + arma_debug_check( (A.is_square() == false), "sqrtmat(): given matrix must be square sized" ); + + if(A.n_elem == 0) + { + out.reset(); + return true; + } + else + if(A.n_elem == 1) + { + out.set_size(1,1); + out[0] = std::sqrt( std::complex( A[0] ) ); + return true; + } + + if(A.is_diagmat()) + { + arma_extra_debug_print("op_sqrtmat: detected diagonal matrix"); + + const uword N = A.n_rows; + + out.zeros(N,N); // aliasing can't happen as op_sqrtmat is defined as cx_mat = op(mat) + + for(uword i=0; i= in_T(0)) + { + out.at(i,i) = std::sqrt(val); + } + else + { + out.at(i,i) = std::sqrt( out_T(val) ); + } + } + + return true; + } + + const bool try_sympd = arma_config::optimise_sym && sym_helper::guess_sympd(A); + + if(try_sympd) + { + arma_extra_debug_print("op_sqrtmat: attempting sympd optimisation"); + + // if matrix A is sympd, all its eigenvalues are positive + + Col eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, A, 'd', "sqrtmat()"); + + if(eig_status) + { + // ensure each eigenvalue is > 0 + + const uword N = eigval.n_elem; + const in_T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i >::from( eigvec * diagmat(eigval) * eigvec.t() ); + + return true; + } + } + + arma_extra_debug_print("op_sqrtmat: sympd optimisation failed"); + + // fallthrough if eigen decomposition failed or an eigenvalue is <= 0 + } + + + Mat U; + Mat S(A.n_rows, A.n_cols, arma_nozeros_indicator()); + + const in_T* Amem = A.memptr(); + out_T* Smem = S.memptr(); + + const uword n_elem = A.n_elem; + + for(uword i=0; i( Amem[i] ); + } + + const bool schur_ok = auxlib::schur(U,S); + + if(schur_ok == false) + { + arma_extra_debug_print("sqrtmat(): schur decomposition failed"); + out.soft_reset(); + return false; + } + + const bool status = op_sqrtmat_cx::helper(S); + + const Mat X = U*S; + + S.reset(); + + out = X*U.t(); + + return status; + } + + + +template +inline +void +op_sqrtmat_cx::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_sqrtmat_cx::apply_direct(out, in.m); + + if(status == false) + { + arma_debug_warn_level(3, "sqrtmat(): given matrix is singular; may not have a square root"); + } + } + + + +template +inline +bool +op_sqrtmat_cx::apply_direct(Mat& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const diagmat_proxy P(expr.m); + + bool status = false; + + if(P.is_alias(out)) + { + Mat tmp; + + status = op_sqrtmat_cx::apply_direct_noalias(tmp, P); + + out.steal_mem(tmp); + } + else + { + status = op_sqrtmat_cx::apply_direct_noalias(out, P); + } + + return status; + } + + + +template +inline +bool +op_sqrtmat_cx::apply_direct_noalias(Mat& out, const diagmat_proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_check( (P.n_rows != P.n_cols), "sqrtmat(): given matrix must be square sized" ); + + const uword N = P.n_rows; + + out.zeros(N,N); + + const eT zero = eT(0); + + bool singular = false; + + for(uword i=0; i +inline +bool +op_sqrtmat_cx::apply_direct(Mat& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + Mat U; + Mat S = expr.get_ref(); + + arma_debug_check( (S.n_rows != S.n_cols), "sqrtmat(): given matrix must be square sized" ); + + if(S.n_elem == 0) + { + out.reset(); + return true; + } + else + if(S.n_elem == 1) + { + out.set_size(1,1); + out[0] = std::sqrt(S[0]); + return true; + } + + if(S.is_diagmat()) + { + arma_extra_debug_print("op_sqrtmat_cx: detected diagonal matrix"); + + const uword N = S.n_rows; + + out.zeros(N,N); // aliasing can't happen as S is generated + + for(uword i=0; i eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, S, 'd', "sqrtmat()"); + + if(eig_status) + { + // ensure each eigenvalue is > 0 + + const uword N = eigval.n_elem; + const T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i X = U*S; + + S.reset(); + + out = X*U.t(); + + return status; + } + + + +template +inline +bool +op_sqrtmat_cx::helper(Mat< std::complex >& S) + { + typedef typename std::complex eT; + + if(S.is_empty()) { return true; } + + const uword N = S.n_rows; + + const eT zero = eT(0); + + eT& S_00 = S[0]; + + bool singular = (S_00 == zero); + + S_00 = std::sqrt(S_00); + + for(uword j=1; j < N; ++j) + { + eT* S_j = S.colptr(j); + + eT& S_jj = S_j[j]; + + singular = (singular || (S_jj == zero)); + + S_jj = std::sqrt(S_jj); + + for(uword ii=0; ii <= (j-1); ++ii) + { + const uword i = (j-1) - ii; + + const eT* S_i = S.colptr(i); + + //S_j[i] /= (S_i[i] + S_j[j]); + S_j[i] /= (S_i[i] + S_jj); + + for(uword k=0; k < i; ++k) + { + S_j[k] -= S_i[k] * S_j[i]; + } + } + } + + return (singular) ? false : true; + } + + + +template +inline +void +op_sqrtmat_sympd::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_sqrtmat_sympd::apply_direct(out, in.m); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("sqrtmat_sympd(): transformation failed"); + } + } + + + +template +inline +bool +op_sqrtmat_sympd::apply_direct(Mat& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const unwrap U(expr.get_ref()); + const Mat& X = U.M; + + arma_debug_check( (X.is_square() == false), "sqrtmat_sympd(): given matrix must be square sized" ); + + if((arma_config::debug) && (is_cx::yes) && (sym_helper::check_diag_imag(X) == false)) + { + arma_debug_warn_level(1, "sqrtmat_sympd(): imaginary components on the diagonal are non-zero"); + } + + if(is_op_diagmat::value || X.is_diagmat()) + { + arma_extra_debug_print("op_sqrtmat_sympd: detected diagonal matrix"); + + out = X; + + eT* colmem = out.memptr(); + + const uword N = X.n_rows; + + for(uword i=0; i eigval; + Mat eigvec; + + const bool status = eig_sym_helper(eigval, eigvec, X, 'd', "sqrtmat_sympd()"); + + if(status == false) { return false; } + + const uword N = eigval.n_elem; + const T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i