// SPDX-License-Identifier: Apache-2.0 // // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ------------------------------------------------------------------------ //! \addtogroup op_inv_spd //! @{ template inline void op_inv_spd_default::apply(Mat& out, const Op& X) { arma_extra_debug_sigprint(); const bool status = op_inv_spd_default::apply_direct(out, X.m); if(status == false) { out.soft_reset(); arma_stop_runtime_error("inv_sympd(): matrix is singular or not positive definite"); } } template inline bool op_inv_spd_default::apply_direct(Mat& out, const Base& expr) { arma_extra_debug_sigprint(); return op_inv_spd_full::apply_direct(out, expr, uword(0)); } // template inline void op_inv_spd_full::apply(Mat& out, const Op& X) { arma_extra_debug_sigprint(); const uword flags = X.aux_uword_a; const bool status = op_inv_spd_full::apply_direct(out, X.m, flags); if(status == false) { out.soft_reset(); arma_stop_runtime_error("inv_sympd(): matrix is singular or not positive definite"); } } template inline bool op_inv_spd_full::apply_direct(Mat& out, const Base& expr, const uword flags) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; typedef typename T1::pod_type T; if(has_user_flags == true ) { arma_extra_debug_print("op_inv_spd_full: has_user_flags = true"); } if(has_user_flags == false) { arma_extra_debug_print("op_inv_spd_full: has_user_flags = false"); } const bool fast = has_user_flags && bool(flags & inv_opts::flag_fast ); const bool allow_approx = has_user_flags && bool(flags & inv_opts::flag_allow_approx); const bool no_ugly = has_user_flags && bool(flags & inv_opts::flag_no_ugly ); if(has_user_flags) { arma_extra_debug_print("op_inv_spd_full: enabled flags:"); if(fast ) { arma_extra_debug_print("fast"); } if(allow_approx) { arma_extra_debug_print("allow_approx"); } if(no_ugly ) { arma_extra_debug_print("no_ugly"); } arma_debug_check( (fast && allow_approx), "inv_sympd(): options 'fast' and 'allow_approx' are mutually exclusive" ); arma_debug_check( (fast && no_ugly ), "inv_sympd(): options 'fast' and 'no_ugly' are mutually exclusive" ); arma_debug_check( (no_ugly && allow_approx), "inv_sympd(): options 'no_ugly' and 'allow_approx' are mutually exclusive" ); } if(no_ugly) { op_inv_spd_state inv_state; const bool status = op_inv_spd_rcond::apply_direct(out, inv_state, expr); // workaround for bug in gcc 4.8 const uword local_size = inv_state.size; const T local_rcond = inv_state.rcond; if((status == false) || (local_rcond < ((std::max)(local_size, uword(1)) * std::numeric_limits::epsilon())) || arma_isnan(local_rcond)) { return false; } return true; } if(allow_approx) { op_inv_spd_state inv_state; Mat tmp; const bool status = op_inv_spd_rcond::apply_direct(tmp, inv_state, expr); // workaround for bug in gcc 4.8 const uword local_size = inv_state.size; const T local_rcond = inv_state.rcond; if((status == false) || (local_rcond < ((std::max)(local_size, uword(1)) * std::numeric_limits::epsilon())) || arma_isnan(local_rcond)) { const Mat A = expr.get_ref(); if(inv_state.is_diag) { return op_pinv::apply_diag(out, A, T(0)); } return op_pinv::apply_sym(out, A, T(0), uword(0)); } out.steal_mem(tmp); return true; } out = expr.get_ref(); arma_debug_check( (out.is_square() == false), "inv_sympd(): given matrix must be square sized", [&](){ out.soft_reset(); } ); if((arma_config::debug) && (arma_config::warn_level > 0)) { if(auxlib::rudimentary_sym_check(out) == false) { if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } } else if((is_cx::yes) && (sym_helper::check_diag_imag(out) == false)) { arma_debug_warn_level(1, "inv_sympd(): imaginary components on diagonal are non-zero"); } } const uword N = out.n_rows; if(N == 0) { return true; } if(is_cx::no) { if(N == 1) { const T a = access::tmp_real(out[0]); out[0] = eT(T(1) / a); return (a > T(0)); } else if(N == 2) { const bool status = op_inv_spd_full::apply_tiny_2x2(out); if(status) { return true; } } // fallthrough if optimisation failed } if(is_op_diagmat::value || out.is_diagmat()) { arma_extra_debug_print("op_inv_spd_full: detected diagonal matrix"); eT* colmem = out.memptr(); for(uword i=0; i inline bool op_inv_spd_full::apply_tiny_2x2(Mat& X) { arma_extra_debug_sigprint(); typedef typename get_pod_type::result T; // NOTE: assuming matrix X is square sized // NOTE: assuming matrix X is symmetric // NOTE: assuming matrix X is real constexpr T det_min = std::numeric_limits::epsilon(); constexpr T det_max = T(1) / std::numeric_limits::epsilon(); eT* Xm = X.memptr(); T a = access::tmp_real(Xm[0]); T c = access::tmp_real(Xm[1]); T d = access::tmp_real(Xm[3]); const T det_val = (a*d - c*c); // positive definite iff all leading principal minors are positive // a = first leading principal minor (top-left 1x1 submatrix) // det_val = second leading principal minor (top-left 2x2 submatrix) if(a <= T(0)) { return false; } // NOTE: since det_min is positive, this also checks whether det_val is positive if((det_val < det_min) || (det_val > det_max) || arma_isnan(det_val)) { return false; } d /= det_val; c /= det_val; a /= det_val; Xm[0] = d; Xm[1] = -c; Xm[2] = -c; Xm[3] = a; return true; } // template inline bool op_inv_spd_rcond::apply_direct(Mat& out, op_inv_spd_state& out_state, const Base& expr) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; typedef typename T1::pod_type T; out = expr.get_ref(); out_state.size = out.n_rows; out_state.rcond = T(0); arma_debug_check( (out.is_square() == false), "inv_sympd(): given matrix must be square sized", [&](){ out.soft_reset(); } ); if((arma_config::debug) && (arma_config::warn_level > 0)) { if(auxlib::rudimentary_sym_check(out) == false) { if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } } else if((is_cx::yes) && (sym_helper::check_diag_imag(out) == false)) { arma_debug_warn_level(1, "inv_sympd(): imaginary components on diagonal are non-zero"); } } if(is_op_diagmat::value || out.is_diagmat()) { arma_extra_debug_print("op_inv_spd_rcond: detected diagonal matrix"); out_state.is_diag = true; eT* colmem = out.memptr(); T max_abs_src_val = T(0); T max_abs_inv_val = T(0); const uword N = out.n_rows; for(uword i=0; i max_abs_src_val) ? abs_src_val : max_abs_src_val; max_abs_inv_val = (abs_inv_val > max_abs_inv_val) ? abs_inv_val : max_abs_inv_val; colmem += N; } out_state.rcond = T(1) / (max_abs_src_val * max_abs_inv_val); return true; } if(auxlib::crippled_lapack(out)) { arma_extra_debug_print("op_inv_spd_rcond: workaround for crippled lapack"); Mat tmp = out; bool sympd_state = false; auxlib::inv_sympd(out, sympd_state); if(sympd_state == false) { out.soft_reset(); out_state.rcond = T(0); return false; } out_state.rcond = auxlib::rcond(tmp); if(out_state.rcond == T(0)) { out.soft_reset(); return false; } return true; } bool is_sympd_junk = false; return auxlib::inv_sympd_rcond(out, is_sympd_junk, out_state.rcond); } //! @}