From eda5bc26f44ee9a6f83dcf8c91f17296d7fc509d Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 12 Feb 2024 14:52:43 +0100 Subject: Move into version control --- src/armadillo/include/armadillo_bits/fn_misc.hpp | 587 +++++++++++++++++++++++ 1 file changed, 587 insertions(+) create mode 100644 src/armadillo/include/armadillo_bits/fn_misc.hpp (limited to 'src/armadillo/include/armadillo_bits/fn_misc.hpp') diff --git a/src/armadillo/include/armadillo_bits/fn_misc.hpp b/src/armadillo/include/armadillo_bits/fn_misc.hpp new file mode 100644 index 0000000..51930e4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_misc.hpp @@ -0,0 +1,587 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_misc +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_Mat::value, + out_type + >::result +linspace + ( + const typename out_type::pod_type start, + const typename out_type::pod_type end, + const uword num = 100u + ) + { + arma_extra_debug_sigprint(); + + typedef typename out_type::elem_type eT; + typedef typename out_type::pod_type T; + + out_type x; + + if(num == 1) + { + x.set_size(1); + + x[0] = eT(end); + } + else + if(num >= 2) + { + x.set_size(num); + + eT* x_mem = x.memptr(); + + const uword num_m1 = num - 1; + + if(is_non_integral::value) + { + const T delta = (end-start)/T(num_m1); + + for(uword i=0; i= start) ? double(end-start)/double(num_m1) : -double(start-end)/double(num_m1); + + for(uword i=0; i(start, end, num); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_Mat::value && is_real::value), + out_type + >::result +logspace + ( + const typename out_type::pod_type A, + const typename out_type::pod_type B, + const uword N = 50u + ) + { + arma_extra_debug_sigprint(); + + typedef typename out_type::elem_type eT; + typedef typename out_type::pod_type T; + + out_type x = linspace(A,B,N); + + const uword n_elem = x.n_elem; + + eT* x_mem = x.memptr(); + + for(uword i=0; i < n_elem; ++i) + { + x_mem[i] = std::pow(T(10), x_mem[i]); + } + + return x; + } + + + +arma_warn_unused +inline +vec +logspace(const double A, const double B, const uword N = 50u) + { + arma_extra_debug_sigprint(); + return logspace(A, B, N); + } + + + +// +// log_exp_add + +template +arma_warn_unused +inline +typename arma_real_only::result +log_add_exp(eT log_a, eT log_b) + { + if(log_a < log_b) + { + std::swap(log_a, log_b); + } + + const eT negdelta = log_b - log_a; + + if( (negdelta < Datum::log_min) || (arma_isfinite(negdelta) == false) ) + { + return log_a; + } + else + { + return (log_a + std::log1p(std::exp(negdelta))); + } + } + + + +// for compatibility with earlier versions +template +arma_warn_unused +inline +typename arma_real_only::result +log_add(eT log_a, eT log_b) + { + return log_add_exp(log_a, log_b); + } + + + +//! kept for compatibility with old user code +template +arma_warn_unused +arma_inline +bool +is_finite(const eT x, const typename arma_scalar_only::result* junk = nullptr) + { + arma_ignore(junk); + + return arma_isfinite(x); + } + + + +//! kept for compatibility with old user code +template +arma_warn_unused +inline +bool +is_finite(const Base& X) + { + arma_extra_debug_sigprint(); + + return X.is_finite(); + } + + + +//! kept for compatibility with old user code +template +arma_warn_unused +inline +bool +is_finite(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return X.is_finite(); + } + + + +//! kept for compatibility with old user code +template +arma_warn_unused +inline +bool +is_finite(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + return X.is_finite(); + } + + + +template +inline +void +swap(Mat& A, Mat& B) + { + arma_extra_debug_sigprint(); + + A.swap(B); + } + + + +template +inline +void +swap(Cube& A, Cube& B) + { + arma_extra_debug_sigprint(); + + A.swap(B); + } + + + +arma_warn_unused +inline +uvec +ind2sub(const SizeMat& s, const uword i) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + + arma_debug_check( (i >= (s_n_rows * s.n_cols) ), "ind2sub(): index out of range" ); + + const uword row = i % s_n_rows; + const uword col = i / s_n_rows; + + uvec out(2, arma_nozeros_indicator()); + + uword* out_mem = out.memptr(); + + out_mem[0] = row; + out_mem[1] = col; + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_arma_type::value && is_same_type::yes), umat >::result +ind2sub(const SizeMat& s, const T1& indices) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_elem = s_n_rows * s.n_cols; + + const Proxy P(indices); + + const uword P_n_rows = P.get_n_rows(); + const uword P_n_cols = P.get_n_cols(); + const uword P_n_elem = P.get_n_elem(); + + const bool P_is_empty = (P_n_elem == 0); + const bool P_is_vec = ((P_n_rows == 1) || (P_n_cols == 1)); + + arma_debug_check( ((P_is_empty == false) && (P_is_vec == false)), "ind2sub(): parameter 'indices' must be a vector" ); + + umat out(2, P_n_elem, arma_nozeros_indicator()); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword count=0; count < P_n_elem; ++count) + { + const uword i = Pea[count]; + + arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" ); + + const uword row = i % s_n_rows; + const uword col = i / s_n_rows; + + uword* out_colptr = out.colptr(count); + + out_colptr[0] = row; + out_colptr[1] = col; + } + } + else + { + if(P_n_rows == 1) + { + for(uword count=0; count < P_n_cols; ++count) + { + const uword i = P.at(0,count); + + arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" ); + + const uword row = i % s_n_rows; + const uword col = i / s_n_rows; + + uword* out_colptr = out.colptr(count); + + out_colptr[0] = row; + out_colptr[1] = col; + } + } + else + if(P_n_cols == 1) + { + for(uword count=0; count < P_n_rows; ++count) + { + const uword i = P.at(count,0); + + arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" ); + + const uword row = i % s_n_rows; + const uword col = i / s_n_rows; + + uword* out_colptr = out.colptr(count); + + out_colptr[0] = row; + out_colptr[1] = col; + } + } + } + + return out; + } + + + +arma_warn_unused +inline +uvec +ind2sub(const SizeCube& s, const uword i) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_elem_slice = s_n_rows * s.n_cols; + + arma_debug_check( (i >= (s_n_elem_slice * s.n_slices) ), "ind2sub(): index out of range" ); + + const uword slice = i / s_n_elem_slice; + const uword j = i - (slice * s_n_elem_slice); + const uword row = j % s_n_rows; + const uword col = j / s_n_rows; + + uvec out(3, arma_nozeros_indicator()); + + uword* out_mem = out.memptr(); + + out_mem[0] = row; + out_mem[1] = col; + out_mem[2] = slice; + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_arma_type::value && is_same_type::yes), umat >::result +ind2sub(const SizeCube& s, const T1& indices) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_elem_slice = s_n_rows * s.n_cols; + const uword s_n_elem = s.n_slices * s_n_elem_slice; + + const quasi_unwrap U(indices); + + arma_debug_check( ((U.M.is_empty() == false) && (U.M.is_vec() == false)), "ind2sub(): parameter 'indices' must be a vector" ); + + const uword U_n_elem = U.M.n_elem; + const uword* U_mem = U.M.memptr(); + + umat out(3, U_n_elem, arma_nozeros_indicator()); + + for(uword count=0; count < U_n_elem; ++count) + { + const uword i = U_mem[count]; + + arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" ); + + const uword slice = i / s_n_elem_slice; + const uword j = i - (slice * s_n_elem_slice); + const uword row = j % s_n_rows; + const uword col = j / s_n_rows; + + uword* out_colptr = out.colptr(count); + + out_colptr[0] = row; + out_colptr[1] = col; + out_colptr[2] = slice; + } + + return out; + } + + + +arma_warn_unused +arma_inline +uword +sub2ind(const SizeMat& s, const uword row, const uword col) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + + arma_debug_check( ((row >= s_n_rows) || (col >= s.n_cols)), "sub2ind(): subscript out of range" ); + + return uword(row + col*s_n_rows); + } + + + +template +arma_warn_unused +inline +uvec +sub2ind(const SizeMat& s, const Base& subscripts) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + const quasi_unwrap U(subscripts.get_ref()); + + arma_debug_check( (U.M.n_rows != 2), "sub2ind(): matrix of subscripts must have 2 rows" ); + + const uword U_M_n_cols = U.M.n_cols; + + uvec out(U_M_n_cols, arma_nozeros_indicator()); + + uword* out_mem = out.memptr(); + const uword* U_M_mem = U.M.memptr(); + + for(uword count=0; count < U_M_n_cols; ++count) + { + const uword row = U_M_mem[0]; + const uword col = U_M_mem[1]; + + U_M_mem += 2; // next column + + arma_debug_check( ((row >= s_n_rows) || (col >= s_n_cols)), "sub2ind(): subscript out of range" ); + + out_mem[count] = uword(row + col*s_n_rows); + } + + return out; + } + + + +arma_warn_unused +arma_inline +uword +sub2ind(const SizeCube& s, const uword row, const uword col, const uword slice) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_check( ((row >= s_n_rows) || (col >= s_n_cols) || (slice >= s.n_slices)), "sub2ind(): subscript out of range" ); + + return uword( (slice * s_n_rows * s_n_cols) + (col * s_n_rows) + row ); + } + + + +template +arma_warn_unused +inline +uvec +sub2ind(const SizeCube& s, const Base& subscripts) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + const uword s_n_slices = s.n_slices; + + const quasi_unwrap U(subscripts.get_ref()); + + arma_debug_check( (U.M.n_rows != 3), "sub2ind(): matrix of subscripts must have 3 rows" ); + + const uword U_M_n_cols = U.M.n_cols; + + uvec out(U_M_n_cols, arma_nozeros_indicator()); + + uword* out_mem = out.memptr(); + const uword* U_M_mem = U.M.memptr(); + + for(uword count=0; count < U_M_n_cols; ++count) + { + const uword row = U_M_mem[0]; + const uword col = U_M_mem[1]; + const uword slice = U_M_mem[2]; + + U_M_mem += 3; // next column + + arma_debug_check( ((row >= s_n_rows) || (col >= s_n_cols) || (slice >= s_n_slices)), "sub2ind(): subscript out of range" ); + + out_mem[count] = uword( (slice * s_n_rows * s_n_cols) + (col * s_n_rows) + row ); + } + + return out; + } + + + +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_same_type::value), + const Glue + >::result +affmul(const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + return Glue(A,B); + } + + + +//! @} -- cgit v1.2.1