// SPDX-License-Identifier: Apache-2.0 // // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ------------------------------------------------------------------------ //! \addtogroup op_dot //! @{ //! for two arrays, generic version for non-complex values template arma_inline typename arma_not_cx::result op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) { arma_extra_debug_sigprint(); #if defined(__FAST_MATH__) { eT val = eT(0); for(uword i=0; i inline typename arma_cx_only::result op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) { arma_extra_debug_sigprint(); typedef typename get_pod_type::result T; T val_real = T(0); T val_imag = T(0); for(uword i=0; i& X = A[i]; const std::complex& Y = B[i]; const T a = X.real(); const T b = X.imag(); const T c = Y.real(); const T d = Y.imag(); val_real += (a*c) - (b*d); val_imag += (a*d) + (b*c); } return std::complex(val_real, val_imag); } //! for two arrays, float and double version template inline typename arma_real_only::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) { arma_extra_debug_sigprint(); if( n_elem <= 32u ) { return op_dot::direct_dot_arma(n_elem, A, B); } else { #if defined(ARMA_USE_ATLAS) { arma_extra_debug_print("atlas::cblas_dot()"); return atlas::cblas_dot(n_elem, A, B); } #elif defined(ARMA_USE_BLAS) { arma_extra_debug_print("blas::dot()"); return blas::dot(n_elem, A, B); } #else { return op_dot::direct_dot_arma(n_elem, A, B); } #endif } } //! for two arrays, complex version template inline typename arma_cx_only::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) { if( n_elem <= 16u ) { return op_dot::direct_dot_arma(n_elem, A, B); } else { #if defined(ARMA_USE_ATLAS) { arma_extra_debug_print("atlas::cblas_cx_dot()"); return atlas::cblas_cx_dot(n_elem, A, B); } #elif defined(ARMA_USE_BLAS) { arma_extra_debug_print("blas::dot()"); return blas::dot(n_elem, A, B); } #else { return op_dot::direct_dot_arma(n_elem, A, B); } #endif } } //! for two arrays, integral version template inline typename arma_integral_only::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) { return op_dot::direct_dot_arma(n_elem, A, B); } //! for three arrays template inline eT op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C) { arma_extra_debug_sigprint(); eT val = eT(0); for(uword i=0; i inline typename T1::elem_type op_dot::apply(const T1& X, const T2& Y) { arma_extra_debug_sigprint(); const bool use_at = (Proxy::use_at) || (Proxy::use_at); const bool have_direct_mem = (quasi_unwrap::has_orig_mem) && (quasi_unwrap::has_orig_mem); if(use_at || have_direct_mem) { const quasi_unwrap A(X); const quasi_unwrap B(Y); arma_debug_check( (A.M.n_elem != B.M.n_elem), "dot(): objects must have the same number of elements" ); return op_dot::direct_dot(A.M.n_elem, A.M.memptr(), B.M.memptr()); } else { if(is_subview_row::value && is_subview_row::value) { typedef typename T1::elem_type eT; const subview_row& A = reinterpret_cast< const subview_row& >(X); const subview_row& B = reinterpret_cast< const subview_row& >(Y); if( (A.m.n_rows == 1) && (B.m.n_rows == 1) ) { arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" ); const eT* A_mem = A.m.memptr(); const eT* B_mem = B.m.memptr(); return op_dot::direct_dot(A.n_elem, &A_mem[A.aux_col1], &B_mem[B.aux_col1]); } } const Proxy PA(X); const Proxy PB(Y); arma_debug_check( (PA.get_n_elem() != PB.get_n_elem()), "dot(): objects must have the same number of elements" ); if(is_Mat::stored_type>::value && is_Mat::stored_type>::value) { const quasi_unwrap::stored_type> A(PA.Q); const quasi_unwrap::stored_type> B(PB.Q); return op_dot::direct_dot(A.M.n_elem, A.M.memptr(), B.M.memptr()); } return op_dot::apply_proxy(PA,PB); } } template inline typename arma_not_cx::result op_dot::apply_proxy(const Proxy& PA, const Proxy& PB) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; typedef typename Proxy::ea_type ea_type1; typedef typename Proxy::ea_type ea_type2; const uword N = PA.get_n_elem(); ea_type1 A = PA.get_ea(); ea_type2 B = PB.get_ea(); eT val1 = eT(0); eT val2 = eT(0); uword i,j; for(i=0, j=1; j inline typename arma_cx_only::result op_dot::apply_proxy(const Proxy& PA, const Proxy& PB) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; typedef typename get_pod_type::result T; typedef typename Proxy::ea_type ea_type1; typedef typename Proxy::ea_type ea_type2; const uword N = PA.get_n_elem(); ea_type1 A = PA.get_ea(); ea_type2 B = PB.get_ea(); T val_real = T(0); T val_imag = T(0); for(uword i=0; i xx = A[i]; const std::complex yy = B[i]; const T a = xx.real(); const T b = xx.imag(); const T c = yy.real(); const T d = yy.imag(); val_real += (a*c) - (b*d); val_imag += (a*d) + (b*c); } return std::complex(val_real, val_imag); } // // op_norm_dot template inline typename T1::elem_type op_norm_dot::apply(const T1& X, const T2& Y) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; typedef typename T1::pod_type T; const quasi_unwrap tmp1(X); const quasi_unwrap tmp2(Y); const Col A( const_cast(tmp1.M.memptr()), tmp1.M.n_elem, false ); const Col B( const_cast(tmp2.M.memptr()), tmp2.M.n_elem, false ); arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" ); const T denom = norm(A,2) * norm(B,2); return (denom != T(0)) ? ( op_dot::apply(A,B) / denom ) : eT(0); } // // op_cdot template inline eT op_cdot::direct_cdot_arma(const uword n_elem, const eT* const A, const eT* const B) { arma_extra_debug_sigprint(); typedef typename get_pod_type::result T; T val_real = T(0); T val_imag = T(0); for(uword i=0; i& X = A[i]; const std::complex& Y = B[i]; const T a = X.real(); const T b = X.imag(); const T c = Y.real(); const T d = Y.imag(); val_real += (a*c) + (b*d); val_imag += (a*d) - (b*c); } return std::complex(val_real, val_imag); } template inline eT op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B) { arma_extra_debug_sigprint(); if( n_elem <= 32u ) { return op_cdot::direct_cdot_arma(n_elem, A, B); } else { #if defined(ARMA_USE_BLAS) { arma_extra_debug_print("blas::gemv()"); // using gemv() workaround due to compatibility issues with cdotc() and zdotc() const char trans = 'C'; const blas_int m = blas_int(n_elem); const blas_int n = 1; //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1); const blas_int inc = 1; const eT alpha = eT(1); const eT beta = eT(0); eT result[2]; // paranoia: using two elements instead of one //blas::gemv(&trans, &m, &n, &alpha, A, &lda, B, &inc, &beta, &result[0], &inc); blas::gemv(&trans, &m, &n, &alpha, A, &m, B, &inc, &beta, &result[0], &inc); return result[0]; } #else { return op_cdot::direct_cdot_arma(n_elem, A, B); } #endif } } template inline typename T1::elem_type op_cdot::apply(const T1& X, const T2& Y) { arma_extra_debug_sigprint(); if(is_Mat::value && is_Mat::value) { return op_cdot::apply_unwrap(X,Y); } else { return op_cdot::apply_proxy(X,Y); } } template inline typename T1::elem_type op_cdot::apply_unwrap(const T1& X, const T2& Y) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; const unwrap tmp1(X); const unwrap tmp2(Y); const Mat& A = tmp1.M; const Mat& B = tmp2.M; arma_debug_check( (A.n_elem != B.n_elem), "cdot(): objects must have the same number of elements" ); return op_cdot::direct_cdot( A.n_elem, A.mem, B.mem ); } template inline typename T1::elem_type op_cdot::apply_proxy(const T1& X, const T2& Y) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; typedef typename get_pod_type::result T; typedef typename Proxy::ea_type ea_type1; typedef typename Proxy::ea_type ea_type2; const bool use_at = (Proxy::use_at) || (Proxy::use_at); if(use_at == false) { const Proxy PA(X); const Proxy PB(Y); const uword N = PA.get_n_elem(); arma_debug_check( (N != PB.get_n_elem()), "cdot(): objects must have the same number of elements" ); ea_type1 A = PA.get_ea(); ea_type2 B = PB.get_ea(); T val_real = T(0); T val_imag = T(0); for(uword i=0; i AA = A[i]; const std::complex BB = B[i]; const T a = AA.real(); const T b = AA.imag(); const T c = BB.real(); const T d = BB.imag(); val_real += (a*c) + (b*d); val_imag += (a*d) - (b*c); } return std::complex(val_real, val_imag); } else { return op_cdot::apply_unwrap( X, Y ); } } template inline typename promote_type::result op_dot_mixed::apply(const T1& A, const T2& B) { arma_extra_debug_sigprint(); typedef typename T1::elem_type in_eT1; typedef typename T2::elem_type in_eT2; typedef typename promote_type::result out_eT; const Proxy PA(A); const Proxy PB(B); const uword N = PA.get_n_elem(); arma_debug_check( (N != PB.get_n_elem()), "dot(): objects must have the same number of elements" ); out_eT acc = out_eT(0); for(uword i=0; i < N; ++i) { acc += upgrade_val::apply(PA[i]) * upgrade_val::apply(PB[i]); } return acc; } //! @}