// SPDX-License-Identifier: Apache-2.0 // // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ------------------------------------------------------------------------ //! \addtogroup gmm_diag //! @{ namespace gmm_priv { template inline gmm_diag::~gmm_diag() { arma_extra_debug_sigprint_this(this); arma_type_check(( (is_same_type::value == false) && (is_same_type::value == false) )); } template inline gmm_diag::gmm_diag() { arma_extra_debug_sigprint_this(this); } template inline gmm_diag::gmm_diag(const gmm_diag& x) { arma_extra_debug_sigprint_this(this); init(x); } template inline gmm_diag& gmm_diag::operator=(const gmm_diag& x) { arma_extra_debug_sigprint(); init(x); return *this; } template inline gmm_diag::gmm_diag(const gmm_full& x) { arma_extra_debug_sigprint_this(this); init(x); } template inline gmm_diag& gmm_diag::operator=(const gmm_full& x) { arma_extra_debug_sigprint(); init(x); return *this; } template inline gmm_diag::gmm_diag(const uword in_n_dims, const uword in_n_gaus) { arma_extra_debug_sigprint_this(this); init(in_n_dims, in_n_gaus); } template inline void gmm_diag::reset() { arma_extra_debug_sigprint(); init(0, 0); } template inline void gmm_diag::reset(const uword in_n_dims, const uword in_n_gaus) { arma_extra_debug_sigprint(); init(in_n_dims, in_n_gaus); } template template inline void gmm_diag::set_params(const Base& in_means_expr, const Base& in_dcovs_expr, const Base& in_hefts_expr) { arma_extra_debug_sigprint(); const unwrap tmp1(in_means_expr.get_ref()); const unwrap tmp2(in_dcovs_expr.get_ref()); const unwrap tmp3(in_hefts_expr.get_ref()); const Mat& in_means = tmp1.M; const Mat& in_dcovs = tmp2.M; const Mat& in_hefts = tmp3.M; arma_debug_check ( (arma::size(in_means) != arma::size(in_dcovs)) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1), "gmm_diag::set_params(): given parameters have inconsistent and/or wrong sizes" ); arma_debug_check( (in_means.internal_has_nonfinite()), "gmm_diag::set_params(): given means have non-finite values" ); arma_debug_check( (in_dcovs.internal_has_nonfinite()), "gmm_diag::set_params(): given dcovs have non-finite values" ); arma_debug_check( (in_hefts.internal_has_nonfinite()), "gmm_diag::set_params(): given hefts have non-finite values" ); arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_params(): given dcovs have negative or zero values" ); arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_params(): given hefts have negative values" ); const eT s = accu(in_hefts); arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_params(): sum of given hefts is not 1" ); access::rw(means) = in_means; access::rw(dcovs) = in_dcovs; access::rw(hefts) = in_hefts; init_constants(); } template template inline void gmm_diag::set_means(const Base& in_means_expr) { arma_extra_debug_sigprint(); const unwrap tmp(in_means_expr.get_ref()); const Mat& in_means = tmp.M; arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_diag::set_means(): given means have incompatible size" ); arma_debug_check( (in_means.internal_has_nonfinite()), "gmm_diag::set_means(): given means have non-finite values" ); access::rw(means) = in_means; } template template inline void gmm_diag::set_dcovs(const Base& in_dcovs_expr) { arma_extra_debug_sigprint(); const unwrap tmp(in_dcovs_expr.get_ref()); const Mat& in_dcovs = tmp.M; arma_debug_check( (arma::size(in_dcovs) != arma::size(dcovs)), "gmm_diag::set_dcovs(): given dcovs have incompatible size" ); arma_debug_check( (in_dcovs.internal_has_nonfinite()), "gmm_diag::set_dcovs(): given dcovs have non-finite values" ); arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_dcovs(): given dcovs have negative or zero values" ); access::rw(dcovs) = in_dcovs; init_constants(); } template template inline void gmm_diag::set_hefts(const Base& in_hefts_expr) { arma_extra_debug_sigprint(); const unwrap tmp(in_hefts_expr.get_ref()); const Mat& in_hefts = tmp.M; arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_diag::set_hefts(): given hefts have incompatible size" ); arma_debug_check( (in_hefts.internal_has_nonfinite()), "gmm_diag::set_hefts(): given hefts have non-finite values" ); arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_hefts(): given hefts have negative values" ); const eT s = accu(in_hefts); arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_hefts(): sum of given hefts is not 1" ); // make sure all hefts are positive and non-zero const eT* in_hefts_mem = in_hefts.memptr(); eT* hefts_mem = access::rw(hefts).memptr(); for(uword i=0; i < hefts.n_elem; ++i) { hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits::min() ); } access::rw(hefts) /= accu(hefts); log_hefts = log(hefts); } template inline uword gmm_diag::n_dims() const { return means.n_rows; } template inline uword gmm_diag::n_gaus() const { return means.n_cols; } template inline bool gmm_diag::load(const std::string name) { arma_extra_debug_sigprint(); Cube Q; bool status = Q.load(name, arma_binary); if( (status == false) || (Q.n_slices != 2) ) { reset(); arma_debug_warn_level(3, "gmm_diag::load(): problem with loading or incompatible format"); return false; } if( (Q.n_rows < 2) || (Q.n_cols < 1) ) { reset(); return true; } access::rw(hefts) = Q.slice(0).row(0); access::rw(means) = Q.slice(0).submat(1, 0, Q.n_rows-1, Q.n_cols-1); access::rw(dcovs) = Q.slice(1).submat(1, 0, Q.n_rows-1, Q.n_cols-1); init_constants(); return true; } template inline bool gmm_diag::save(const std::string name) const { arma_extra_debug_sigprint(); Cube Q(means.n_rows + 1, means.n_cols, 2, arma_nozeros_indicator()); if(Q.n_elem > 0) { Q.slice(0).row(0) = hefts; Q.slice(1).row(0).zeros(); // reserved for future use Q.slice(0).submat(1, 0, arma::size(means)) = means; Q.slice(1).submat(1, 0, arma::size(dcovs)) = dcovs; } const bool status = Q.save(name, arma_binary); return status; } template inline Col gmm_diag::generate() const { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; Col out( ((N_gaus > 0) ? N_dims : uword(0)), fill::randn ); if(N_gaus > 0) { const double val = randu(); double csum = double(0); uword gaus_id = 0; for(uword j=0; j < N_gaus; ++j) { csum += hefts[j]; if(val <= csum) { gaus_id = j; break; } } out %= sqrt(dcovs.col(gaus_id)); out += means.col(gaus_id); } return out; } template inline Mat gmm_diag::generate(const uword N_vec) const { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; Mat out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn ); if(N_gaus > 0) { const eT* hefts_mem = hefts.memptr(); const Mat sqrt_dcovs = sqrt(dcovs); for(uword i=0; i < N_vec; ++i) { const double val = randu(); double csum = double(0); uword gaus_id = 0; for(uword j=0; j < N_gaus; ++j) { csum += hefts_mem[j]; if(val <= csum) { gaus_id = j; break; } } subview_col out_col = out.col(i); out_col %= sqrt_dcovs.col(gaus_id); out_col += means.col(gaus_id); } } return out; } template template inline eT gmm_diag::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk2) const { arma_extra_debug_sigprint(); arma_ignore(junk1); arma_ignore(junk2); const quasi_unwrap tmp(expr); arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); return internal_scalar_log_p( tmp.M.memptr() ); } template template inline eT gmm_diag::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk2) const { arma_extra_debug_sigprint(); arma_ignore(junk2); const quasi_unwrap tmp(expr); arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" ); return internal_scalar_log_p( tmp.M.memptr(), gaus_id ); } template template inline Row gmm_diag::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2) const { arma_extra_debug_sigprint(); arma_ignore(junk1); arma_ignore(junk2); const quasi_unwrap tmp(expr); const Mat& X = tmp.M; return internal_vec_log_p(X); } template template inline Row gmm_diag::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2) const { arma_extra_debug_sigprint(); arma_ignore(junk2); const quasi_unwrap tmp(expr); const Mat& X = tmp.M; return internal_vec_log_p(X, gaus_id); } template template inline eT gmm_diag::sum_log_p(const Base& expr) const { arma_extra_debug_sigprint(); const quasi_unwrap tmp(expr.get_ref()); const Mat& X = tmp.M; return internal_sum_log_p(X); } template template inline eT gmm_diag::sum_log_p(const Base& expr, const uword gaus_id) const { arma_extra_debug_sigprint(); const quasi_unwrap tmp(expr.get_ref()); const Mat& X = tmp.M; return internal_sum_log_p(X, gaus_id); } template template inline eT gmm_diag::avg_log_p(const Base& expr) const { arma_extra_debug_sigprint(); const quasi_unwrap tmp(expr.get_ref()); const Mat& X = tmp.M; return internal_avg_log_p(X); } template template inline eT gmm_diag::avg_log_p(const Base& expr, const uword gaus_id) const { arma_extra_debug_sigprint(); const quasi_unwrap tmp(expr.get_ref()); const Mat& X = tmp.M; return internal_avg_log_p(X, gaus_id); } template template inline uword gmm_diag::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk) const { arma_extra_debug_sigprint(); arma_ignore(junk); const quasi_unwrap tmp(expr); const Mat& X = tmp.M; return internal_scalar_assign(X, dist); } template template inline urowvec gmm_diag::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk) const { arma_extra_debug_sigprint(); arma_ignore(junk); urowvec out; const quasi_unwrap tmp(expr); const Mat& X = tmp.M; internal_vec_assign(out, X, dist); return out; } template template inline urowvec gmm_diag::raw_hist(const Base& expr, const gmm_dist_mode& dist_mode) const { arma_extra_debug_sigprint(); const unwrap tmp(expr.get_ref()); const Mat& X = tmp.M; arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::raw_hist(): incompatible dimensions" ); arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::raw_hist(): unsupported distance mode" ); urowvec hist; internal_raw_hist(hist, X, dist_mode); return hist; } template template inline Row gmm_diag::norm_hist(const Base& expr, const gmm_dist_mode& dist_mode) const { arma_extra_debug_sigprint(); const unwrap tmp(expr.get_ref()); const Mat& X = tmp.M; arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::norm_hist(): incompatible dimensions" ); arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::norm_hist(): unsupported distance mode" ); urowvec hist; internal_raw_hist(hist, X, dist_mode); const uword hist_n_elem = hist.n_elem; const uword* hist_mem = hist.memptr(); eT acc = eT(0); for(uword i=0; i out(hist_n_elem, arma_nozeros_indicator()); eT* out_mem = out.memptr(); for(uword i=0; i template inline bool gmm_diag::learn ( const Base& data, const uword N_gaus, const gmm_dist_mode& dist_mode, const gmm_seed_mode& seed_mode, const uword km_iter, const uword em_iter, const eT var_floor, const bool print_mode ) { arma_extra_debug_sigprint(); const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist); const bool seed_mode_ok = \ (seed_mode == keep_existing) || (seed_mode == static_subset) || (seed_mode == static_spread) || (seed_mode == random_subset) || (seed_mode == random_spread); arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode must be eucl_dist or maha_dist" ); arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown seed_mode" ); arma_debug_check( (var_floor < eT(0) ), "gmm_diag::learn(): variance floor is negative" ); const unwrap tmp_X(data.get_ref()); const Mat& X = tmp_X.M; if(X.is_empty() ) { arma_debug_warn_level(3, "gmm_diag::learn(): given matrix is empty" ); return false; } if(X.internal_has_nonfinite()) { arma_debug_warn_level(3, "gmm_diag::learn(): given matrix has non-finite values"); return false; } if(N_gaus == 0) { reset(); return true; } if(dist_mode == maha_dist) { mah_aux = var(X,1,1); const uword mah_aux_n_elem = mah_aux.n_elem; eT* mah_aux_mem = mah_aux.memptr(); for(uword i=0; i < mah_aux_n_elem; ++i) { const eT val = mah_aux_mem[i]; mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1); } } // copy current model, in case of failure by k-means and/or EM const gmm_diag orig = (*this); // initial means if(seed_mode == keep_existing) { if(means.is_empty() ) { arma_debug_warn_level(3, "gmm_diag::learn(): no existing means" ); return false; } if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "gmm_diag::learn(): dimensionality mismatch"); return false; } // TODO: also check for number of vectors? } else { if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "gmm_diag::learn(): number of vectors is less than number of gaussians"); return false; } reset(X.n_rows, N_gaus); if(print_mode) { get_cout_stream() << "gmm_diag::learn(): generating initial means\n"; get_cout_stream().flush(); } if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mode); } else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mode); } } // k-means if(km_iter > 0) { const arma_ostream_state stream_state(get_cout_stream()); bool status = false; if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); } else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); } stream_state.restore(get_cout_stream()); if(status == false) { arma_debug_warn_level(3, "gmm_diag::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; } } // initial dcovs const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits::min(); if(seed_mode != keep_existing) { if(print_mode) { get_cout_stream() << "gmm_diag::learn(): generating initial covariances\n"; get_cout_stream().flush(); } if(dist_mode == eucl_dist) { generate_initial_params<1>(X, var_floor_actual); } else if(dist_mode == maha_dist) { generate_initial_params<2>(X, var_floor_actual); } } // EM algorithm if(em_iter > 0) { const arma_ostream_state stream_state(get_cout_stream()); const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode); stream_state.restore(get_cout_stream()); if(status == false) { arma_debug_warn_level(3, "gmm_diag::learn(): EM algorithm failed"); init(orig); return false; } } mah_aux.reset(); init_constants(); return true; } template template inline bool gmm_diag::kmeans_wrapper ( Mat& user_means, const Base& data, const uword N_gaus, const gmm_seed_mode& seed_mode, const uword km_iter, const bool print_mode ) { arma_extra_debug_sigprint(); const bool seed_mode_ok = \ (seed_mode == keep_existing) || (seed_mode == static_subset) || (seed_mode == static_spread) || (seed_mode == random_subset) || (seed_mode == random_spread); arma_debug_check( (seed_mode_ok == false), "kmeans(): unknown seed_mode" ); const unwrap tmp_X(data.get_ref()); const Mat& X = tmp_X.M; if(X.is_empty() ) { arma_debug_warn_level(3, "kmeans(): given matrix is empty" ); return false; } if(X.internal_has_nonfinite()) { arma_debug_warn_level(3, "kmeans(): given matrix has non-finite values"); return false; } if(N_gaus == 0) { reset(); return true; } // initial means if(seed_mode == keep_existing) { access::rw(means) = user_means; if(means.is_empty() ) { arma_debug_warn_level(3, "kmeans(): no existing means" ); return false; } if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "kmeans(): dimensionality mismatch"); return false; } // TODO: also check for number of vectors? } else { if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "kmeans(): number of vectors is less than number of means"); return false; } access::rw(means).zeros(X.n_rows, N_gaus); if(print_mode) { get_cout_stream() << "kmeans(): generating initial means\n"; } generate_initial_means<1>(X, seed_mode); } // k-means if(km_iter > 0) { const arma_ostream_state stream_state(get_cout_stream()); bool status = false; status = km_iterate<1>(X, km_iter, print_mode, "kmeans()"); stream_state.restore(get_cout_stream()); if(status == false) { arma_debug_warn_level(3, "kmeans(): clustering failed; not enough data, or too many means requested"); return false; } } return true; } // // // template inline void gmm_diag::init(const gmm_diag& x) { arma_extra_debug_sigprint(); gmm_diag& t = *this; if(&t != &x) { access::rw(t.means) = x.means; access::rw(t.dcovs) = x.dcovs; access::rw(t.hefts) = x.hefts; init_constants(); } } template inline void gmm_diag::init(const gmm_full& x) { arma_extra_debug_sigprint(); access::rw(hefts) = x.hefts; access::rw(means) = x.means; const uword N_dims = x.means.n_rows; const uword N_gaus = x.means.n_cols; access::rw(dcovs).zeros(N_dims,N_gaus); for(uword g=0; g < N_gaus; ++g) { const Mat& fcov = x.fcovs.slice(g); eT* dcov_mem = access::rw(dcovs).colptr(g); for(uword d=0; d < N_dims; ++d) { dcov_mem[d] = fcov.at(d,d); } } init_constants(); } template inline void gmm_diag::init(const uword in_n_dims, const uword in_n_gaus) { arma_extra_debug_sigprint(); access::rw(means).zeros(in_n_dims, in_n_gaus); access::rw(dcovs).ones(in_n_dims, in_n_gaus); access::rw(hefts).set_size(in_n_gaus); access::rw(hefts).fill(eT(1) / eT(in_n_gaus)); init_constants(); } template inline void gmm_diag::init_constants() { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; // inv_dcovs.copy_size(dcovs); const eT* dcovs_mem = dcovs.memptr(); eT* inv_dcovs_mem = inv_dcovs.memptr(); const uword dcovs_n_elem = dcovs.n_elem; for(uword i=0; i < dcovs_n_elem; ++i) { inv_dcovs_mem[i] = eT(1) / (std::max)( dcovs_mem[i], std::numeric_limits::min() ); } // const eT tmp = (eT(N_dims)/eT(2)) * std::log(Datum::tau); log_det_etc.set_size(N_gaus); for(uword g=0; g < N_gaus; ++g) { const eT* dcovs_colmem = dcovs.colptr(g); eT log_det_val = eT(0); for(uword d=0; d < N_dims; ++d) { log_det_val += std::log( (std::max)( dcovs_colmem[d], std::numeric_limits::min() ) ); } log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val ); } // eT* hefts_mem = access::rw(hefts).memptr(); for(uword g=0; g < N_gaus; ++g) { hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits::min() ); } log_hefts = log(hefts); } template inline umat gmm_diag::internal_gen_boundaries(const uword N) const { arma_extra_debug_sigprint(); #if defined(ARMA_USE_OPENMP) const uword n_threads_avail = (omp_in_parallel()) ? uword(1) : uword(omp_get_max_threads()); const uword n_threads = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1; #else static constexpr uword n_threads = 1; #endif // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): n_threads: " << n_threads << '\n'; umat boundaries(2, n_threads, arma_nozeros_indicator()); if(N > 0) { const uword chunk_size = N / n_threads; uword count = 0; for(uword t=0; t inline eT gmm_diag::internal_scalar_log_p(const eT* x) const { arma_extra_debug_sigprint(); const eT* log_hefts_mem = log_hefts.mem; const uword N_gaus = means.n_cols; if(N_gaus > 0) { eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0]; for(uword g=1; g < N_gaus; ++g) { const eT tmp = internal_scalar_log_p(x, g) + log_hefts_mem[g]; log_sum = log_add_exp(log_sum, tmp); } return log_sum; } else { return -Datum::inf; } } template inline eT gmm_diag::internal_scalar_log_p(const eT* x, const uword g) const { arma_extra_debug_sigprint(); const eT* mean = means.colptr(g); const eT* inv_dcov = inv_dcovs.colptr(g); const uword N_dims = means.n_rows; eT val_i = eT(0); eT val_j = eT(0); uword i,j; for(i=0, j=1; j inline Row gmm_diag::internal_vec_log_p(const Mat& X) const { arma_extra_debug_sigprint(); arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); const uword N = X.n_cols; Row out(N, arma_nozeros_indicator()); if(N > 0) { #if defined(ARMA_USE_OPENMP) { const umat boundaries = internal_gen_boundaries(N); const uword n_threads = boundaries.n_cols; #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) { const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); eT* out_mem = out.memptr(); for(uword i=start_index; i <= end_index; ++i) { out_mem[i] = internal_scalar_log_p( X.colptr(i) ); } } } #else { eT* out_mem = out.memptr(); for(uword i=0; i < N; ++i) { out_mem[i] = internal_scalar_log_p( X.colptr(i) ); } } #endif } return out; } template inline Row gmm_diag::internal_vec_log_p(const Mat& X, const uword gaus_id) const { arma_extra_debug_sigprint(); arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" ); const uword N = X.n_cols; Row out(N, arma_nozeros_indicator()); if(N > 0) { #if defined(ARMA_USE_OPENMP) { const umat boundaries = internal_gen_boundaries(N); const uword n_threads = boundaries.n_cols; #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) { const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); eT* out_mem = out.memptr(); for(uword i=start_index; i <= end_index; ++i) { out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id ); } } } #else { eT* out_mem = out.memptr(); for(uword i=0; i < N; ++i) { out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id ); } } #endif } return out; } template inline eT gmm_diag::internal_sum_log_p(const Mat& X) const { arma_extra_debug_sigprint(); arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions" ); const uword N = X.n_cols; if(N == 0) { return (-Datum::inf); } #if defined(ARMA_USE_OPENMP) { const umat boundaries = internal_gen_boundaries(N); const uword n_threads = boundaries.n_cols; Col t_accs(n_threads, arma_zeros_indicator()); #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) { const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); eT t_acc = eT(0); for(uword i=start_index; i <= end_index; ++i) { t_acc += internal_scalar_log_p( X.colptr(i) ); } t_accs[t] = t_acc; } return eT(accu(t_accs)); } #else { eT acc = eT(0); for(uword i=0; i inline eT gmm_diag::internal_sum_log_p(const Mat& X, const uword gaus_id) const { arma_extra_debug_sigprint(); arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions" ); arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::sum_log_p(): specified gaussian is out of range" ); const uword N = X.n_cols; if(N == 0) { return (-Datum::inf); } #if defined(ARMA_USE_OPENMP) { const umat boundaries = internal_gen_boundaries(N); const uword n_threads = boundaries.n_cols; Col t_accs(n_threads, arma_zeros_indicator()); #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) { const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); eT t_acc = eT(0); for(uword i=start_index; i <= end_index; ++i) { t_acc += internal_scalar_log_p( X.colptr(i), gaus_id ); } t_accs[t] = t_acc; } return eT(accu(t_accs)); } #else { eT acc = eT(0); for(uword i=0; i inline eT gmm_diag::internal_avg_log_p(const Mat& X) const { arma_extra_debug_sigprint(); arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" ); const uword N = X.n_cols; if(N == 0) { return (-Datum::inf); } #if defined(ARMA_USE_OPENMP) { const umat boundaries = internal_gen_boundaries(N); const uword n_threads = boundaries.n_cols; field< running_mean_scalar > t_running_means(n_threads); #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) { const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); running_mean_scalar& current_running_mean = t_running_means[t]; for(uword i=start_index; i <= end_index; ++i) { current_running_mean( internal_scalar_log_p( X.colptr(i) ) ); } } eT avg = eT(0); for(uword t=0; t < n_threads; ++t) { running_mean_scalar& current_running_mean = t_running_means[t]; const eT w = eT(current_running_mean.count()) / eT(N); avg += w * current_running_mean.mean(); } return avg; } #else { running_mean_scalar running_mean; for(uword i=0; i inline eT gmm_diag::internal_avg_log_p(const Mat& X, const uword gaus_id) const { arma_extra_debug_sigprint(); arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" ); arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::avg_log_p(): specified gaussian is out of range" ); const uword N = X.n_cols; if(N == 0) { return (-Datum::inf); } #if defined(ARMA_USE_OPENMP) { const umat boundaries = internal_gen_boundaries(N); const uword n_threads = boundaries.n_cols; field< running_mean_scalar > t_running_means(n_threads); #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) { const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); running_mean_scalar& current_running_mean = t_running_means[t]; for(uword i=start_index; i <= end_index; ++i) { current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) ); } } eT avg = eT(0); for(uword t=0; t < n_threads; ++t) { running_mean_scalar& current_running_mean = t_running_means[t]; const eT w = eT(current_running_mean.count()) / eT(N); avg += w * current_running_mean.mean(); } return avg; } #else { running_mean_scalar running_mean; for(uword i=0; i inline uword gmm_diag::internal_scalar_assign(const Mat& X, const gmm_dist_mode& dist_mode) const { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" ); arma_debug_check( (N_gaus == 0), "gmm_diag::assign(): model has no means" ); const eT* X_mem = X.colptr(0); if(dist_mode == eucl_dist) { eT best_dist = Datum::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const eT tmp_dist = distance::eval(N_dims, X_mem, means.colptr(g), X_mem); if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } } return best_g; } else if(dist_mode == prob_dist) { const eT* log_hefts_mem = log_hefts.memptr(); eT best_p = -Datum::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g]; if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; } } return best_g; } else { arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode"); } return uword(0); } template inline void gmm_diag::internal_vec_assign(urowvec& out, const Mat& X, const gmm_dist_mode& dist_mode) const { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" ); const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0; out.set_size(1,X_n_cols); uword* out_mem = out.memptr(); if(dist_mode == eucl_dist) { #if defined(ARMA_USE_OPENMP) { #pragma omp parallel for schedule(static) for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), X_colptr); if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } } out_mem[i] = best_g; } } #else { for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), X_colptr); if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } } out_mem[i] = best_g; } } #endif } else if(dist_mode == prob_dist) { #if defined(ARMA_USE_OPENMP) { const eT* log_hefts_mem = log_hefts.memptr(); #pragma omp parallel for schedule(static) for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g= best_p) { best_p = tmp_p; best_g = g; } } out_mem[i] = best_g; } } #else { const eT* log_hefts_mem = log_hefts.memptr(); for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g= best_p) { best_p = tmp_p; best_g = g; } } out_mem[i] = best_g; } } #endif } else { arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode"); } } template inline void gmm_diag::internal_raw_hist(urowvec& hist, const Mat& X, const gmm_dist_mode& dist_mode) const { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; const uword X_n_cols = X.n_cols; hist.zeros(N_gaus); if(N_gaus == 0) { return; } #if defined(ARMA_USE_OPENMP) { const umat boundaries = internal_gen_boundaries(X_n_cols); const uword n_threads = boundaries.n_cols; field thread_hist(n_threads); for(uword t=0; t < n_threads; ++t) { thread_hist(t).zeros(N_gaus); } if(dist_mode == eucl_dist) { #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) { uword* thread_hist_mem = thread_hist(t).memptr(); const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); for(uword i=start_index; i <= end_index; ++i) { const eT* X_colptr = X.colptr(i); eT best_dist = Datum::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const eT tmp_dist = distance::eval(N_dims, X_colptr, means.colptr(g), X_colptr); if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } } thread_hist_mem[best_g]++; } } } else if(dist_mode == prob_dist) { const eT* log_hefts_mem = log_hefts.memptr(); #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) { uword* thread_hist_mem = thread_hist(t).memptr(); const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); for(uword i=start_index; i <= end_index; ++i) { const eT* X_colptr = X.colptr(i); eT best_p = -Datum::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g]; if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; } } thread_hist_mem[best_g]++; } } } // reduction hist = thread_hist(0); for(uword t=1; t < n_threads; ++t) { hist += thread_hist(t); } } #else { uword* hist_mem = hist.memptr(); if(dist_mode == eucl_dist) { for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const eT tmp_dist = distance::eval(N_dims, X_colptr, means.colptr(g), X_colptr); if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } } hist_mem[best_g]++; } } else if(dist_mode == prob_dist) { const eT* log_hefts_mem = log_hefts.memptr(); for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g]; if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; } } hist_mem[best_g]++; } } } #endif } template template inline void gmm_diag::generate_initial_means(const Mat& X, const gmm_seed_mode& seed_mode) { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; if( (seed_mode == static_subset) || (seed_mode == random_subset) ) { uvec initial_indices; if(seed_mode == static_subset) { initial_indices = linspace(0, X.n_cols-1, N_gaus); } else if(seed_mode == random_subset) { initial_indices = randperm(X.n_cols, N_gaus); } // initial_indices.print("initial_indices:"); access::rw(means) = X.cols(initial_indices); } else if( (seed_mode == static_spread) || (seed_mode == random_spread) ) { // going through all of the samples can be extremely time consuming; // instead, if there are enough samples, randomly choose samples with probability 0.1 const bool use_sampling = ((X.n_cols/uword(100)) > N_gaus); const uword step = (use_sampling) ? uword(10) : uword(1); uword start_index = 0; if(seed_mode == static_spread) { start_index = X.n_cols / 2; } else if(seed_mode == random_spread) { start_index = as_scalar(randi(1, distr_param(0,X.n_cols-1))); } access::rw(means).col(0) = X.unsafe_col(start_index); const eT* mah_aux_mem = mah_aux.memptr(); running_stat rs; for(uword g=1; g < N_gaus; ++g) { eT max_dist = eT(0); uword best_i = uword(0); uword start_i = uword(0); if(use_sampling) { uword start_i_proposed = uword(0); if(seed_mode == static_spread) { start_i_proposed = g % uword(10); } if(seed_mode == random_spread) { start_i_proposed = as_scalar(randi(1, distr_param(0,9))); } if(start_i_proposed < X.n_cols) { start_i = start_i_proposed; } } for(uword i=start_i; i < X.n_cols; i += step) { rs.reset(); const eT* X_colptr = X.colptr(i); bool ignore_i = false; // find the average distance between sample i and the means so far for(uword h = 0; h < g; ++h) { const eT dist = distance::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem); // ignore sample already selected as a mean if(dist == eT(0)) { ignore_i = true; break; } else { rs(dist); } } if( (rs.mean() >= max_dist) && (ignore_i == false)) { max_dist = eT(rs.mean()); best_i = i; } } // set the mean to the sample that is the furthest away from the means so far access::rw(means).col(g) = X.unsafe_col(best_i); } } // get_cout_stream() << "generate_initial_means():" << '\n'; // means.print(); } template template inline void gmm_diag::generate_initial_params(const Mat& X, const eT var_floor) { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; const eT* mah_aux_mem = mah_aux.memptr(); const uword X_n_cols = X.n_cols; if(X_n_cols == 0) { return; } // as the covariances are calculated via accumulators, // the means also need to be calculated via accumulators to ensure numerical consistency Mat acc_means(N_dims, N_gaus, arma_zeros_indicator()); Mat acc_dcovs(N_dims, N_gaus, arma_zeros_indicator()); Row acc_hefts(N_gaus, arma_zeros_indicator()); uword* acc_hefts_mem = acc_hefts.memptr(); #if defined(ARMA_USE_OPENMP) { const umat boundaries = internal_gen_boundaries(X_n_cols); const uword n_threads = boundaries.n_cols; field< Mat > t_acc_means(n_threads); field< Mat > t_acc_dcovs(n_threads); field< Row > t_acc_hefts(n_threads); for(uword t=0; t < n_threads; ++t) { t_acc_means(t).zeros(N_dims, N_gaus); t_acc_dcovs(t).zeros(N_dims, N_gaus); t_acc_hefts(t).zeros(N_gaus); } #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) { uword* t_acc_hefts_mem = t_acc_hefts(t).memptr(); const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); for(uword i=start_index; i <= end_index; ++i) { const eT* X_colptr = X.colptr(i); eT min_dist = Datum::inf; uword best_g = 0; for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem); if(dist < min_dist) { min_dist = dist; best_g = g; } } eT* t_acc_mean = t_acc_means(t).colptr(best_g); eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g); for(uword d=0; d::inf; uword best_g = 0; for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem); if(dist < min_dist) { min_dist = dist; best_g = g; } } eT* acc_mean = acc_means.colptr(best_g); eT* acc_dcov = acc_dcovs.colptr(best_g); for(uword d=0; d= 1) ? tmp : eT(0); dcov[d] = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor); } hefts_mem[g] = eT(acc_heft) / eT(X_n_cols); } em_fix_params(var_floor); } //! multi-threaded implementation of k-means, inspired by MapReduce template template inline bool gmm_diag::km_iterate(const Mat& X, const uword max_iter, const bool verbose, const char* signature) { arma_extra_debug_sigprint(); if(verbose) { get_cout_stream().unsetf(ios::showbase); get_cout_stream().unsetf(ios::uppercase); get_cout_stream().unsetf(ios::showpos); get_cout_stream().unsetf(ios::scientific); get_cout_stream().setf(ios::right); get_cout_stream().setf(ios::fixed); } const uword X_n_cols = X.n_cols; if(X_n_cols == 0) { return true; } const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; const eT* mah_aux_mem = mah_aux.memptr(); Mat acc_means(N_dims, N_gaus, arma_zeros_indicator()); Row acc_hefts( N_gaus, arma_zeros_indicator()); Row last_indx( N_gaus, arma_zeros_indicator()); Mat new_means = means; Mat old_means = means; running_mean_scalar rs_delta; #if defined(ARMA_USE_OPENMP) const umat boundaries = internal_gen_boundaries(X_n_cols); const uword n_threads = boundaries.n_cols; field< Mat > t_acc_means(n_threads); field< Row > t_acc_hefts(n_threads); field< Row > t_last_indx(n_threads); #else const uword n_threads = 1; #endif if(verbose) { get_cout_stream() << signature << ": n_threads: " << n_threads << '\n'; get_cout_stream().flush(); } for(uword iter=1; iter <= max_iter; ++iter) { #if defined(ARMA_USE_OPENMP) { for(uword t=0; t < n_threads; ++t) { t_acc_means(t).zeros(N_dims, N_gaus); t_acc_hefts(t).zeros(N_gaus); t_last_indx(t).zeros(N_gaus); } #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) { Mat& t_acc_means_t = t_acc_means(t); uword* t_acc_hefts_mem = t_acc_hefts(t).memptr(); uword* t_last_indx_mem = t_last_indx(t).memptr(); const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); for(uword i=start_index; i <= end_index; ++i) { const eT* X_colptr = X.colptr(i); eT min_dist = Datum::inf; uword best_g = 0; for(uword g=0; g::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem); if(dist < min_dist) { min_dist = dist; best_g = g; } } eT* t_acc_mean = t_acc_means_t.colptr(best_g); for(uword d=0; d= 1 ) { last_indx(g) = t_last_indx(t)(g); } } } #else { acc_hefts.zeros(); acc_means.zeros(); last_indx.zeros(); uword* acc_hefts_mem = acc_hefts.memptr(); uword* last_indx_mem = last_indx.memptr(); for(uword i=0; i < X_n_cols; ++i) { const eT* X_colptr = X.colptr(i); eT min_dist = Datum::inf; uword best_g = 0; for(uword g=0; g::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem); if(dist < min_dist) { min_dist = dist; best_g = g; } } eT* acc_mean = acc_means.colptr(best_g); for(uword d=0; d= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0); } } // heuristics to resurrect dead means const uvec dead_gs = find(acc_hefts == uword(0)); if(dead_gs.n_elem > 0) { if(verbose) { get_cout_stream() << signature << ": recovering from dead means\n"; get_cout_stream().flush(); } uword* last_indx_mem = last_indx.memptr(); const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" ); if(live_gs.n_elem == 0) { return false; } uword live_gs_count = 0; for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count) { const uword dead_g_id = dead_gs(dead_gs_count); uword proposed_i = 0; if(live_gs_count < live_gs.n_elem) { const uword live_g_id = live_gs(live_gs_count); ++live_gs_count; if(live_g_id == dead_g_id) { return false; } // recover by using a sample from a known good mean proposed_i = last_indx_mem[live_g_id]; } else { // recover by using a randomly seleced sample (last resort) proposed_i = as_scalar(randi(1, distr_param(0,X_n_cols-1))); } if(proposed_i >= X_n_cols) { return false; } new_means.col(dead_g_id) = X.col(proposed_i); } } rs_delta.reset(); for(uword g=0; g < N_gaus; ++g) { rs_delta( distance::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) ); } if(verbose) { get_cout_stream() << signature << ": iteration: "; get_cout_stream().unsetf(ios::scientific); get_cout_stream().setf(ios::fixed); get_cout_stream().width(std::streamsize(4)); get_cout_stream() << iter; get_cout_stream() << " delta: "; get_cout_stream().unsetf(ios::fixed); //get_cout_stream().setf(ios::scientific); get_cout_stream() << rs_delta.mean() << '\n'; get_cout_stream().flush(); } arma::swap(old_means, new_means); if(rs_delta.mean() <= Datum::eps) { break; } } access::rw(means) = old_means; if(means.internal_has_nonfinite()) { return false; } return true; } //! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce template inline bool gmm_diag::em_iterate(const Mat& X, const uword max_iter, const eT var_floor, const bool verbose) { arma_extra_debug_sigprint(); if(X.n_cols == 0) { return true; } const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; if(verbose) { get_cout_stream().unsetf(ios::showbase); get_cout_stream().unsetf(ios::uppercase); get_cout_stream().unsetf(ios::showpos); get_cout_stream().unsetf(ios::scientific); get_cout_stream().setf(ios::right); get_cout_stream().setf(ios::fixed); } const umat boundaries = internal_gen_boundaries(X.n_cols); const uword n_threads = boundaries.n_cols; field< Mat > t_acc_means(n_threads); field< Mat > t_acc_dcovs(n_threads); field< Col > t_acc_norm_lhoods(n_threads); field< Col > t_gaus_log_lhoods(n_threads); Col t_progress_log_lhood(n_threads, arma_nozeros_indicator()); for(uword t=0; t::inf; for(uword iter=1; iter <= max_iter; ++iter) { init_constants(); em_update_params(X, boundaries, t_acc_means, t_acc_dcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood); em_fix_params(var_floor); const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem); if(verbose) { get_cout_stream() << "gmm_diag::learn(): EM: iteration: "; get_cout_stream().unsetf(ios::scientific); get_cout_stream().setf(ios::fixed); get_cout_stream().width(std::streamsize(4)); get_cout_stream() << iter; get_cout_stream() << " avg_log_p: "; get_cout_stream().unsetf(ios::fixed); //get_cout_stream().setf(ios::scientific); get_cout_stream() << new_avg_log_p << '\n'; get_cout_stream().flush(); } if(arma_isfinite(new_avg_log_p) == false) { return false; } if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum::eps) { break; } old_avg_log_p = new_avg_log_p; } if(any(vectorise(dcovs) <= eT(0))) { return false; } if(means.internal_has_nonfinite()) { return false; } if(dcovs.internal_has_nonfinite()) { return false; } if(hefts.internal_has_nonfinite()) { return false; } return true; } template inline void gmm_diag::em_update_params ( const Mat& X, const umat& boundaries, field< Mat >& t_acc_means, field< Mat >& t_acc_dcovs, field< Col >& t_acc_norm_lhoods, field< Col >& t_gaus_log_lhoods, Col& t_progress_log_lhood ) { arma_extra_debug_sigprint(); const uword n_threads = boundaries.n_cols; // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts #if defined(ARMA_USE_OPENMP) { #pragma omp parallel for schedule(static) for(uword t=0; t& acc_means = t_acc_means[t]; Mat& acc_dcovs = t_acc_dcovs[t]; Col& acc_norm_lhoods = t_acc_norm_lhoods[t]; Col& gaus_log_lhoods = t_gaus_log_lhoods[t]; eT& progress_log_lhood = t_progress_log_lhood[t]; em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_dcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood); } } #else { em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_dcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]); } #endif const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; Mat& final_acc_means = t_acc_means[0]; Mat& final_acc_dcovs = t_acc_dcovs[0]; Col& final_acc_norm_lhoods = t_acc_norm_lhoods[0]; // the "reduce" operation, which combines the partial accumulators produced by the separate threads for(uword t=1; t::min() ); // // eT* mean_mem = access::rw(means).colptr(g); // eT* dcov_mem = access::rw(dcovs).colptr(g); // // eT* acc_mean_mem = final_acc_means.colptr(g); // eT* acc_dcov_mem = final_acc_dcovs.colptr(g); // // hefts_mem[g] = acc_norm_lhood / eT(X.n_cols); // // for(uword d=0; d < N_dims; ++d) // { // const eT tmp = acc_mean_mem[d] / acc_norm_lhood; // // mean_mem[d] = tmp; // dcov_mem[d] = acc_dcov_mem[d] / acc_norm_lhood - tmp*tmp; // } // } // conditionally update each component; if only a subset of the hefts was updated, em_fix_params() will sanitise them for(uword g=0; g < N_gaus; ++g) { const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits::min() ); if(arma_isfinite(acc_norm_lhood) == false) { continue; } eT* acc_mean_mem = final_acc_means.colptr(g); eT* acc_dcov_mem = final_acc_dcovs.colptr(g); bool ok = true; for(uword d=0; d < N_dims; ++d) { const eT tmp1 = acc_mean_mem[d] / acc_norm_lhood; const eT tmp2 = acc_dcov_mem[d] / acc_norm_lhood - tmp1*tmp1; acc_mean_mem[d] = tmp1; acc_dcov_mem[d] = tmp2; if(arma_isfinite(tmp2) == false) { ok = false; } } if(ok) { hefts_mem[g] = acc_norm_lhood / eT(X.n_cols); eT* mean_mem = access::rw(means).colptr(g); eT* dcov_mem = access::rw(dcovs).colptr(g); for(uword d=0; d < N_dims; ++d) { mean_mem[d] = acc_mean_mem[d]; dcov_mem[d] = acc_dcov_mem[d]; } } } } template inline void gmm_diag::em_generate_acc ( const Mat& X, const uword start_index, const uword end_index, Mat& acc_means, Mat& acc_dcovs, Col& acc_norm_lhoods, Col& gaus_log_lhoods, eT& progress_log_lhood ) const { arma_extra_debug_sigprint(); progress_log_lhood = eT(0); acc_means.zeros(); acc_dcovs.zeros(); acc_norm_lhoods.zeros(); gaus_log_lhoods.zeros(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; const eT* log_hefts_mem = log_hefts.memptr(); eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr(); for(uword i=start_index; i <= end_index; i++) { const eT* x = X.colptr(i); for(uword g=0; g < N_gaus; ++g) { gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g]; } eT log_lhood_sum = gaus_log_lhoods_mem[0]; for(uword g=1; g < N_gaus; ++g) { log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]); } progress_log_lhood += log_lhood_sum; for(uword g=0; g < N_gaus; ++g) { const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum); acc_norm_lhoods[g] += norm_lhood; eT* acc_mean_mem = acc_means.colptr(g); eT* acc_dcov_mem = acc_dcovs.colptr(g); for(uword d=0; d < N_dims; ++d) { const eT x_d = x[d]; const eT y_d = x_d * norm_lhood; acc_mean_mem[d] += y_d; acc_dcov_mem[d] += y_d * x_d; // equivalent to x_d * x_d * norm_lhood } } } progress_log_lhood /= eT((end_index - start_index) + 1); } template inline void gmm_diag::em_fix_params(const eT var_floor) { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; const eT var_ceiling = std::numeric_limits::max(); const uword dcovs_n_elem = dcovs.n_elem; eT* dcovs_mem = access::rw(dcovs).memptr(); for(uword i=0; i < dcovs_n_elem; ++i) { eT& var_val = dcovs_mem[i]; if(var_val < var_floor ) { var_val = var_floor; } else if(var_val > var_ceiling) { var_val = var_ceiling; } else if(arma_isnan(var_val) ) { var_val = eT(1); } } eT* hefts_mem = access::rw(hefts).memptr(); for(uword g1=0; g1 < N_gaus; ++g1) { if(hefts_mem[g1] > eT(0)) { const eT* means_colptr_g1 = means.colptr(g1); for(uword g2=(g1+1); g2 < N_gaus; ++g2) { if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits::epsilon()) ) { const eT dist = distance::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1); if(dist == eT(0)) { hefts_mem[g2] = eT(0); } } } } } const eT heft_floor = std::numeric_limits::min(); const eT heft_initial = eT(1) / eT(N_gaus); for(uword i=0; i < N_gaus; ++i) { eT& heft_val = hefts_mem[i]; if(heft_val < heft_floor) { heft_val = heft_floor; } else if(heft_val > eT(1) ) { heft_val = eT(1); } else if(arma_isnan(heft_val) ) { heft_val = heft_initial; } } const eT heft_sum = accu(hefts); if((heft_sum < (eT(1) - Datum::eps)) || (heft_sum > (eT(1) + Datum::eps))) { access::rw(hefts) /= heft_sum; } } } // namespace gmm_priv //! @}